diff --git a/CMakeLists.txt b/CMakeLists.txt index 707bc603..c1e3a940 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,6 +54,13 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cp add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0) +# Define GGML version variables for shared library SOVERSION +# These are required by ggml/src/CMakeLists.txt for proper library versioning +set(GGML_VERSION_MAJOR 0) +set(GGML_VERSION_MINOR 0) +set(GGML_VERSION_PATCH 0) +set(GGML_VERSION "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") + set(GGML_CPU ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE) diff --git a/Makefile.sync b/Makefile.sync index 2966d43f..c1c24f2f 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggml-org/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=17f7f4baad8b3a716ee139da7bb56ae984e8c0fa +FETCH_HEAD=ec98e2002 .PHONY: help help: diff --git a/api/types.go b/api/types.go index d8467629..63b89897 100644 --- a/api/types.go +++ b/api/types.go @@ -283,11 +283,12 @@ func (pt PropertyType) String() string { } type ToolProperty struct { - AnyOf []ToolProperty `json:"anyOf,omitempty"` - Type PropertyType `json:"type,omitempty"` - Items any `json:"items,omitempty"` - Description string `json:"description,omitempty"` - Enum []any `json:"enum,omitempty"` + AnyOf []ToolProperty `json:"anyOf,omitempty"` + Type PropertyType `json:"type,omitempty"` + Items any `json:"items,omitempty"` + Description string `json:"description,omitempty"` + Enum []any `json:"enum,omitempty"` + Properties map[string]ToolProperty `json:"properties,omitempty"` } // ToTypeScriptType converts a ToolProperty to a TypeScript type string @@ -553,6 +554,9 @@ type CreateRequest struct { Renderer string `json:"renderer,omitempty"` Parser string `json:"parser,omitempty"` + // Requires is the minimum version of Ollama required by the model. + Requires string `json:"requires,omitempty"` + // Info is a map of additional information for the model Info map[string]any `json:"info,omitempty"` @@ -603,6 +607,7 @@ type ShowResponse struct { Tensors []Tensor `json:"tensors,omitempty"` Capabilities []model.Capability `json:"capabilities,omitempty"` ModifiedAt time.Time `json:"modified_at,omitempty"` + Requires string `json:"requires,omitempty"` } // CopyRequest is the request passed to [Client.Copy]. 
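The new `Properties` field makes `ToolProperty` recursive, so tool schemas with nested object parameters survive a JSON round trip (this is what the test that follows exercises). A minimal sketch of how a caller might build such a schema, assuming the `github.com/ollama/ollama/api` import path:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// A tool parameter whose value is an object with its own typed fields,
	// expressed with the new recursive Properties map.
	loc := api.ToolProperty{
		Type:        api.PropertyType{"object"},
		Description: "Location details",
		Properties: map[string]api.ToolProperty{
			"address": {Type: api.PropertyType{"string"}, Description: "Street address"},
			"city":    {Type: api.PropertyType{"string"}, Description: "City name"},
		},
	}

	// Nested properties are emitted under "properties" and read back losslessly.
	data, err := json.Marshal(loc)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data))

	var back api.ToolProperty
	if err := json.Unmarshal(data, &back); err != nil {
		panic(err)
	}
	fmt.Println(back.Properties["city"].Description) // City name
}
```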
diff --git a/api/types_test.go b/api/types_test.go index 187012db..da1581f4 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -504,6 +504,107 @@ func TestThinking_UnmarshalJSON(t *testing.T) { } } +func TestToolPropertyNestedProperties(t *testing.T) { + tests := []struct { + name string + input string + expected ToolProperty + }{ + { + name: "nested object properties", + input: `{ + "type": "object", + "description": "Location details", + "properties": { + "address": { + "type": "string", + "description": "Street address" + }, + "city": { + "type": "string", + "description": "City name" + } + } + }`, + expected: ToolProperty{ + Type: PropertyType{"object"}, + Description: "Location details", + Properties: map[string]ToolProperty{ + "address": { + Type: PropertyType{"string"}, + Description: "Street address", + }, + "city": { + Type: PropertyType{"string"}, + Description: "City name", + }, + }, + }, + }, + { + name: "deeply nested properties", + input: `{ + "type": "object", + "description": "Event", + "properties": { + "location": { + "type": "object", + "description": "Location", + "properties": { + "coordinates": { + "type": "object", + "description": "GPS coordinates", + "properties": { + "lat": {"type": "number", "description": "Latitude"}, + "lng": {"type": "number", "description": "Longitude"} + } + } + } + } + } + }`, + expected: ToolProperty{ + Type: PropertyType{"object"}, + Description: "Event", + Properties: map[string]ToolProperty{ + "location": { + Type: PropertyType{"object"}, + Description: "Location", + Properties: map[string]ToolProperty{ + "coordinates": { + Type: PropertyType{"object"}, + Description: "GPS coordinates", + Properties: map[string]ToolProperty{ + "lat": {Type: PropertyType{"number"}, Description: "Latitude"}, + "lng": {Type: PropertyType{"number"}, Description: "Longitude"}, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var prop ToolProperty + err := json.Unmarshal([]byte(tt.input), &prop) + require.NoError(t, err) + assert.Equal(t, tt.expected, prop) + + // Round-trip test: marshal and unmarshal again + data, err := json.Marshal(prop) + require.NoError(t, err) + + var prop2 ToolProperty + err = json.Unmarshal(data, &prop2) + require.NoError(t, err) + assert.Equal(t, tt.expected, prop2) + }) + } +} + func TestToolFunctionParameters_String(t *testing.T) { tests := []struct { name string diff --git a/app/ui/ui.go b/app/ui/ui.go index 5a64705d..26de7142 100644 --- a/app/ui/ui.go +++ b/app/ui/ui.go @@ -12,13 +12,13 @@ import ( "log/slog" "net/http" "net/http/httputil" - "net/url" "os" "runtime" "runtime/debug" "slices" "strconv" "strings" + "sync" "time" "github.com/google/uuid" @@ -117,40 +117,66 @@ func (s *Server) log() *slog.Logger { // ollamaProxy creates a reverse proxy handler to the Ollama server func (s *Server) ollamaProxy() http.Handler { - ollamaHost := os.Getenv("OLLAMA_HOST") - if ollamaHost == "" { - ollamaHost = "http://127.0.0.1:11434" - } + var ( + proxy http.Handler + proxyMu sync.Mutex + ) - if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") { - ollamaHost = "http://" + ollamaHost - } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyMu.Lock() + p := proxy + proxyMu.Unlock() - target, err := url.Parse(ollamaHost) - if err != nil { - s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, 
"failed to configure proxy", http.StatusInternalServerError) - }) - } + if p == nil { + proxyMu.Lock() + if proxy == nil { + var err error + for i := range 2 { + if i > 0 { + s.log().Warn("ollama server not ready, retrying", "attempt", i+1) + time.Sleep(1 * time.Second) + } - s.log().Info("configuring ollama proxy", "target", target.String()) + err = WaitForServer(context.Background(), 10*time.Second) + if err == nil { + break + } + } - proxy := httputil.NewSingleHostReverseProxy(target) + if err != nil { + proxyMu.Unlock() + s.log().Error("ollama server not ready after retries", "error", err) + http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable) + return + } - originalDirector := proxy.Director - proxy.Director = func(req *http.Request) { - originalDirector(req) - req.Host = target.Host - s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host) - } + target := envconfig.Host() + s.log().Info("configuring ollama proxy", "target", target.String()) - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String()) - http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) - } + newProxy := httputil.NewSingleHostReverseProxy(target) - return proxy + originalDirector := newProxy.Director + newProxy.Director = func(req *http.Request) { + originalDirector(req) + req.Host = target.Host + s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host) + } + + newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String()) + http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway) + } + + proxy = newProxy + p = newProxy + } else { + p = proxy + } + proxyMu.Unlock() + } + + p.ServeHTTP(w, r) + }) } type errHandlerFunc func(http.ResponseWriter, *http.Request) error diff --git a/cmd/cmd.go b/cmd/cmd.go index d77bb2c5..35074ad2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -943,6 +943,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) } rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel}) + if resp.Requires != "" { + rows = append(rows, []string{"", "requires", resp.Requires}) + } return }) diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 1c9d1994..7dc3d009 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -291,6 +291,31 @@ Weigh anchor! 
t.Errorf("unexpected output (-want +got):\n%s", diff) } }) + + t.Run("min version", func(t *testing.T) { + var b bytes.Buffer + if err := showInfo(&api.ShowResponse{ + Details: api.ModelDetails{ + Family: "test", + ParameterSize: "7B", + QuantizationLevel: "FP16", + }, + Requires: "0.14.0", + }, false, &b); err != nil { + t.Fatal(err) + } + + expect := ` Model + architecture test + parameters 7B + quantization FP16 + requires 0.14.0 + +` + if diff := cmp.Diff(expect, b.String()); diff != "" { + t.Errorf("unexpected output (-want +got):\n%s", diff) + } + }) } func TestDeleteHandler(t *testing.T) { diff --git a/convert/convert.go b/convert/convert.go index f0846795..a6d28668 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &qwen25VLModel{} case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration": conv = &qwen3VLModel{} + case "Olmo3ForCausalLM": + conv = &olmoModel{} case "BertModel": conv = &bertModel{} case "NomicBertModel", "NomicBertMoEModel": diff --git a/convert/convert_olmo.go b/convert/convert_olmo.go new file mode 100644 index 00000000..f75c6847 --- /dev/null +++ b/convert/convert_olmo.go @@ -0,0 +1,117 @@ +package convert + +import ( + "cmp" + + "github.com/ollama/ollama/fs/ggml" +) + +type ropeScaling struct { + Factor float32 `json:"factor"` + OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"` + AttentionFactor float32 `json:"attention_factor"` + BetaFast float32 `json:"beta_fast"` + BetaSlow float32 `json:"beta_slow"` + RopeType string `json:"rope_type"` + ExtrapolationFactor float32 `json:"extrapolation_factor"` +} + +type olmoModel struct { + ModelParameters + + HiddenSize uint32 `json:"hidden_size"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + RMSNormEPS float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + RopeScaling *ropeScaling `json:"rope_scaling"` + SlidingWindow uint32 `json:"sliding_window"` + LayerTypes []string `json:"layer_types"` +} + +var _ ModelConverter = (*olmoModel)(nil) + +func (p *olmoModel) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "olmo3" + kv["olmo3.block_count"] = p.NumHiddenLayers + kv["olmo3.context_length"] = p.MaxPositionEmbeddings + kv["olmo3.embedding_length"] = p.HiddenSize + kv["olmo3.feed_forward_length"] = p.IntermediateSize + kv["olmo3.attention.head_count"] = p.NumAttentionHeads + kv["olmo3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) + + if p.RopeTheta > 0 { + kv["olmo3.rope.freq_base"] = p.RopeTheta + } + + if p.RopeScaling != nil { + if p.RopeScaling.Factor > 0 { + kv["olmo3.rope.scaling.factor"] = p.RopeScaling.Factor + } + if p.RopeScaling.OriginalMaxPositionEmbeds > 0 { + kv["olmo3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds + } + if p.RopeScaling.AttentionFactor > 0 { + kv["olmo3.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor + } + if p.RopeScaling.RopeType != "" { + kv["olmo3.rope.scaling.type"] = p.RopeScaling.RopeType + } + } + + if p.RMSNormEPS > 0 { + kv["olmo3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + } + + if p.SlidingWindow > 0 { + kv["olmo3.attention.sliding_window"] = p.SlidingWindow + } + + if len(p.LayerTypes) 
> 0 { + slidingPattern := make([]bool, len(p.LayerTypes)) + for i, layerType := range p.LayerTypes { + slidingPattern[i] = (layerType == "sliding_attention") + } + kv["olmo3.attention.sliding_window_pattern"] = slidingPattern + } + + return kv +} + +func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor { + out := make([]*ggml.Tensor, 0, len(ts)) + for _, t := range ts { + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *olmoModel) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.layers", "blk", + "model.norm", "output_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "self_attn.q_norm", "attn_q_norm", + "self_attn.k_norm", "attn_k_norm", + "post_attention_layernorm", "post_attention_norm", + "post_feedforward_layernorm", "post_ffw_norm", + "mlp.gate_proj", "ffn_gate", + "mlp.down_proj", "ffn_down", + "mlp.up_proj", "ffn_up", + } +} diff --git a/convert/tokenizer_spm.go b/convert/tokenizer_spm.go index 340c3d58..51efb35e 100644 --- a/convert/tokenizer_spm.go +++ b/convert/tokenizer_spm.go @@ -49,7 +49,8 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) { tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) // temporary fix to handle gemma3 broken configs - if slices.Contains([]string{"", ""}, piece.GetPiece()) { + // TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm + if slices.Contains([]string{"", "", "", "", "", "", "", "", ""}, piece.GetPiece()) { tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL) } diff --git a/docs/modelfile.mdx b/docs/modelfile.mdx index a3eca207..ce91bbf6 100644 --- a/docs/modelfile.mdx +++ b/docs/modelfile.mdx @@ -41,6 +41,7 @@ INSTRUCTION arguments | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. | | [`LICENSE`](#license) | Specifies the legal license. | | [`MESSAGE`](#message) | Specify message history. | +| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. | ## Examples @@ -248,6 +249,16 @@ MESSAGE user Is Ontario in Canada? MESSAGE assistant yes ``` +### REQUIRES + +The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model. + +``` +REQUIRES +``` + +The version should be a valid Ollama version (e.g. 0.14.0). + ## Notes - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments. 
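The `REQUIRES` value surfaces through the API as `CreateRequest.Requires` / `ShowResponse.Requires` and in the `ollama show` output above. This patch also adds `golang.org/x/mod` to go.mod; a hedged sketch of how a minimum-version check could be written with that module (a hypothetical helper for illustration, not necessarily how the server enforces the constraint):

```go
package main

import (
	"fmt"

	"golang.org/x/mod/semver"
)

// meetsRequires is a hypothetical helper: it reports whether the running
// Ollama version satisfies a model's REQUIRES constraint. semver.Compare
// expects a leading "v", which Ollama version strings do not carry.
func meetsRequires(current, requires string) bool {
	if requires == "" {
		return true // model records no minimum version
	}
	return semver.Compare("v"+current, "v"+requires) >= 0
}

func main() {
	fmt.Println(meetsRequires("0.14.0", "0.14.0")) // true
	fmt.Println(meetsRequires("0.13.2", "0.14.0")) // false
}
```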
diff --git a/envconfig/config.go b/envconfig/config.go index c0b2e2f0..238e5e6e 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -199,7 +199,7 @@ var ( // MultiUserCache optimizes prompt caching for multi-user scenarios MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") // Enable the new Ollama engine - NewEngine = BoolWithDefault("OLLAMA_NEW_ENGINE") + NewEngine = Bool("OLLAMA_NEW_ENGINE") // ContextLength sets the default context length ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) // Auth enables authentication between the Ollama client and server @@ -291,7 +291,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, - "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(true), "Enable the new Ollama engine"}, + "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, // Informational diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 691ea32b..44a48511 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -241,18 +241,20 @@ func (kv KV) Bools(key string, defaultValue ...[]bool) []bool { func (kv KV) OllamaEngineRequired() bool { return slices.Contains([]string{ + "bert", + "deepseek2", + "deepseekocr", "gemma3", "gemma3n", "gptoss", "gpt-oss", "llama4", "mistral3", "mllama", + "nomic-bert", + "olmo3", "qwen25vl", "qwen3", "qwen3moe", "qwen3vl", "qwen3vlmoe", - "deepseekocr", - "deepseek2", - "nomic-bert", }, kv.Architecture()) } @@ -838,9 +840,11 @@ func (f GGML) SupportsFlashAttention() bool { // FlashAttention checks if the model should enable flash attention func (f GGML) FlashAttention() bool { return slices.Contains([]string{ + "bert", "gemma3", "gptoss", "gpt-oss", "mistral3", + "olmo3", "qwen3", "qwen3moe", "qwen3vl", "qwen3vlmoe", }, f.KV().String("general.architecture")) diff --git a/go.mod b/go.mod index 6e7dc1d6..f7c9ff29 100644 --- a/go.mod +++ b/go.mod @@ -15,8 +15,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.9.0 github.com/x448/float16 v0.8.4 - golang.org/x/sync v0.12.0 - golang.org/x/sys v0.36.0 + golang.org/x/sync v0.17.0 + golang.org/x/sys v0.37.0 ) require ( @@ -29,7 +29,8 @@ require ( github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/tkrajina/typescriptify-golang-structs v0.2.0 golang.org/x/image v0.22.0 - golang.org/x/tools v0.30.0 + golang.org/x/mod v0.30.0 + golang.org/x/tools v0.38.0 gonum.org/v1/gonum v0.15.0 ) @@ -76,11 +77,11 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.36.0 + golang.org/x/crypto v0.43.0 golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/term v0.30.0 - golang.org/x/text v0.23.0 + golang.org/x/net v0.46.0 // indirect + golang.org/x/term v0.36.0 + golang.org/x/text v0.30.0 google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 464cd6fc..936c040a 100644 --- a/go.sum +++ b/go.sum @@ -224,8 +224,8 @@ golang.org/x/crypto 
v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -255,6 +255,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -267,8 +269,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= +golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -278,8 +280,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.17.0 
h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -295,17 +297,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= -golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -319,8 +321,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= -golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/kvcache/causal.go b/kvcache/causal.go index 15a4cdea..e04f828e 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -140,10 +140,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity c.config.CachePadding = 1 } - if c.config.MaskBatchPadding == 0 { - c.config.MaskBatchPadding = 1 - } - if c.config.MaskDType == ml.DTypeOther { c.config.MaskDType = ml.DTypeF32 } @@ -364,15 +360,12 @@ func roundUp(length, pad int) int { // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { - // Align and pad the two dimensions as required by the backend - batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) - c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 length := c.curCellRange.max - c.curCellRange.min + 1 - mask := make([]float32, batchSize*length) + mask := make([]float32, c.curBatchSize*length) for i := range c.curBatchSize { enabled := !slices.Contains(c.opts.Except, i) @@ -386,13 +379,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { } } - // Mask out any padding tokens we added. For padding that we added to the cache history, this - // has already been masked out because the sequence doesn't match. - for i := c.curBatchSize * length; i < len(mask); i++ { - mask[i] = float32(math.Inf(-1)) - } - - maskTensor := ctx.Input().FromFloats(mask, length, batchSize) + maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize) if c.config.MaskDType != ml.DTypeF32 { maskTensor = maskTensor.Cast(ctx, c.config.MaskDType) diff --git a/llama/build-info.cpp b/llama/build-info.cpp index 5666fbc4..b37cd25e 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "17f7f4baad8b3a716ee139da7bb56ae984e8c0fa"; +char const *LLAMA_COMMIT = "ec98e2002"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/.rsync-filter b/llama/llama.cpp/.rsync-filter index df75ca65..7987be1c 100644 --- a/llama/llama.cpp/.rsync-filter +++ b/llama/llama.cpp/.rsync-filter @@ -17,6 +17,9 @@ include /tools/mtmd/clip.cpp include /tools/mtmd/mtmd.cpp include /tools/mtmd/mtmd-audio.cpp include /tools/mtmd/mtmd-helper.cpp +include /tools/mtmd/models/ +include /tools/mtmd/models/*.h +include /tools/mtmd/models/*.cpp include /src/ include /src/llama.* include /src/llama-*.* diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index 0497f90a..5a8cf524 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -1013,31 +1013,40 @@ bool tty_can_use_colors() { // Model utils // -static inline void common_init_sampler_from_model( +// TODO: move to common/sampling +static void common_init_sampler_from_model( const llama_model * model, common_params_sampling & sparams) { const uint64_t config = sparams.user_sampling_config; auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { - if (config & user_config) return; + if (config & user_config) { + return; + } char buf[64] = {0}; if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { char * end = nullptr; int32_t v = strtol(buf, &end, 10); - if (end && end != buf) dst = v; + if (end && end != buf) { + dst = v; + } } 
}; auto get_float = [&](const char * key, float & dst, uint64_t user_config) { - if (config & user_config) return; + if (config & user_config) { + return; + } char buf[128] = {0}; if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { char * end = nullptr; float v = strtof(buf, &end); - if (end && end != buf) dst = v; + if (end && end != buf) { + dst = v; + } } }; @@ -1065,31 +1074,125 @@ static inline void common_init_sampler_from_model( get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); } -struct common_init_result common_init_from_params(common_params & params) { - common_init_result iparams; +struct common_init_result::impl { + impl() = default; + ~impl() = default; + + llama_model_ptr model; + llama_context_ptr context; + + std::vector lora; + + std::vector samplers; +}; + +common_init_result::common_init_result(common_params & params) : + pimpl(new impl{}) { auto mparams = common_model_params_to_llama(params); + auto cparams = common_context_params_to_llama(params); + + if (params.fit_params) { + LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__); + llama_params_fit(params.model.path.c_str(), &mparams, &cparams, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); + } llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", - __func__, params.model.path.c_str()); - return iparams; + return; } - common_init_sampler_from_model(model, params.sampling); + pimpl->model.reset(model); const llama_vocab * vocab = llama_model_get_vocab(model); - auto cparams = common_context_params_to_llama(params); + // updates params.sampling + // TODO: fix naming + common_init_sampler_from_model(model, params.sampling); + + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sampling.ignore_eos = false; + } + + // initialize once + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY); + params.sampling.logit_bias_eog.push_back({i, -INFINITY}); + } + } + + if (params.sampling.ignore_eos) { + // add EOG biases to the active set of logit biases + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); + } + + //if (params.sampling.penalty_last_n == -1) { + // LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + // params.sampling.penalty_last_n = llama_n_ctx(lctx); + //} + + //if (params.sampling.dry_penalty_last_n == -1) { + // LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); + //} + + pimpl->samplers.resize(cparams.n_seq_max); + + for (int i = 0; i < (int) cparams.n_seq_max; ++i) { + pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + } llama_context * lctx = 
llama_init_from_model(model, cparams); if (lctx == NULL) { - LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", - __func__, params.model.path.c_str()); - llama_model_free(model); - return iparams; + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); + return; } + pimpl->context.reset(lctx); +} + +llama_model * common_init_result::model() { + return pimpl->model.get(); +} + +llama_context * common_init_result::context() { + return pimpl->context.get(); +} + +common_sampler * common_init_result::sampler(llama_seq_id seq_id) { + return pimpl->samplers[seq_id].get(); +} + +std::vector & common_init_result::lora() { + return pimpl->lora; +} + +void common_init_result::free_context() { + pimpl->context.reset(); +} + +common_init_result_ptr common_init_from_params(common_params & params) { + common_init_result_ptr res(new common_init_result(params)); + + llama_model * model = res->model(); + if (model == NULL) { + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); + return res; + } + + llama_context * lctx = res->context(); + if (lctx == NULL) { + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); + return res; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; @@ -1101,10 +1204,7 @@ struct common_init_result common_init_from_params(common_params & params) { const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } int err = llama_apply_adapter_cvec( @@ -1115,10 +1215,7 @@ struct common_init_result common_init_from_params(common_params & params) { params.control_vector_layer_start, params.control_vector_layer_end); if (err) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } } @@ -1142,10 +1239,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!ok) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } } @@ -1155,9 +1249,7 @@ struct common_init_result common_init_from_params(common_params & params) { lora.reset(llama_adapter_lora_init(model, la.path.c_str())); if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); - llama_free(lctx); - llama_model_free(model); - return iparams; + return res; } char buf[1024]; @@ -1166,43 +1258,13 @@ struct common_init_result common_init_from_params(common_params & params) { la.task_name = buf; llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); la.prompt_prefix = buf; - iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters } if (!params.lora_init_without_apply) { common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); - params.sampling.ignore_eos = false; - } - - // initialize once - for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { - if (llama_vocab_is_eog(vocab, i)) { - 
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); - params.sampling.logit_bias_eog.push_back({i, -INFINITY}); - } - } - - if (params.sampling.ignore_eos) { - // add EOG biases to the active set of logit biases - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); - } - - if (params.sampling.penalty_last_n == -1) { - LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); - params.sampling.penalty_last_n = llama_n_ctx(lctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); - params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); - } - if (params.warmup) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); @@ -1241,12 +1303,11 @@ struct common_init_result common_init_from_params(common_params & params) { llama_set_warmup(lctx, false); } - iparams.model.reset(model); - iparams.context.reset(lctx); - - return iparams; + return res; } +common_init_result::~common_init_result() = default; + std::string get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. @@ -1255,7 +1316,9 @@ std::string get_model_endpoint() { std::string model_endpoint = "https://huggingface.co/"; if (endpoint_env) { model_endpoint = endpoint_env; - if (model_endpoint.back() != '/') model_endpoint += '/'; + if (model_endpoint.back() != '/') { + model_endpoint += '/'; + } } return model_endpoint; } diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index d28e4899..d7074484 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -82,7 +82,8 @@ int32_t cpu_get_num_math(); enum llama_example { LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, - LLAMA_EXAMPLE_MAIN, + LLAMA_EXAMPLE_COMPLETION, + LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL, @@ -98,6 +99,7 @@ enum llama_example { LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_FINETUNE, + LLAMA_EXAMPLE_FIT_PARAMS, LLAMA_EXAMPLE_COUNT, }; @@ -194,7 +196,6 @@ struct common_params_sampling { std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY - std::vector samplers = { COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_DRY, @@ -215,6 +216,10 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + bool has_logit_bias() const { + return !logit_bias.empty(); + } + // print the parameters into a string std::string print() const; }; @@ -302,8 +307,8 @@ struct lr_opt { struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata); struct common_params { - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 4096; // context size + int32_t n_predict = -1; // max. 
number of new tokens to predict, -1 == no limit + int32_t n_ctx = 0; // context size, 0 == context the model was trained with int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt @@ -324,9 +329,12 @@ struct common_params { // offload params std::vector devices; // devices to use for offloading - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + bool fit_params = true; // whether to fit unset model/context parameters to free device memory + size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory + int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs @@ -406,6 +414,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool no_perf = false; // disable performance metrics + bool show_timings = true; // show timing information on CLI bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache @@ -462,7 +471,7 @@ struct common_params { std::string public_path = ""; // NOLINT std::string api_prefix = ""; // NOLINT std::string chat_template = ""; // NOLINT - bool use_jinja = false; // NOLINT + bool use_jinja = true; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; @@ -482,9 +491,10 @@ struct common_params { bool endpoint_metrics = false; // router server configs - std::string models_dir = ""; // directory containing models for the router server - int models_max = 4; // maximum number of models to load simultaneously - bool models_autoload = true; // automatically load models when requested via the router server + std::string models_dir = ""; // directory containing models for the router server + std::string models_preset = ""; // directory containing model presets for the router server + int models_max = 4; // maximum number of models to load simultaneously + bool models_autoload = true; // automatically load models when requested via the router server bool log_json = false; @@ -666,15 +676,29 @@ bool tty_can_use_colors(); // Model utils // -// note: defines object's lifetime -struct common_init_result { - llama_model_ptr model; - llama_context_ptr context; +struct common_sampler; - std::vector lora; +// note: defines the model, context, samplers, ets. 
lifetimes +struct common_init_result { + common_init_result(common_params & params); + ~common_init_result(); + + llama_model * model(); + llama_context * context(); + common_sampler * sampler(llama_seq_id seq_id); + + std::vector & lora(); + + void free_context(); + +private: + struct impl; + std::unique_ptr pimpl; }; -struct common_init_result common_init_from_params(common_params & params); +using common_init_result_ptr = std::unique_ptr; + +common_init_result_ptr common_init_from_params(common_params & params); struct llama_model_params common_model_params_to_llama ( common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params); diff --git a/llama/llama.cpp/common/json-schema-to-grammar.cpp b/llama/llama.cpp/common/json-schema-to-grammar.cpp index 6be55282..acf00e2d 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.cpp +++ b/llama/llama.cpp/common/json-schema-to-grammar.cpp @@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) { std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); } -class SchemaConverter { +class common_schema_converter { private: + friend class common_schema_info; friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; @@ -729,7 +730,7 @@ private: } public: - SchemaConverter( + common_schema_converter( const std::function & fetch_json, bool dotall) : _fetch_json(fetch_json), _dotall(dotall) @@ -990,6 +991,134 @@ public: } }; +// common_schema_info implementation (pimpl) + +common_schema_info::common_schema_info() + : impl_(std::make_unique( + [](const std::string &) { return json(); }, + false)) {} + +common_schema_info::~common_schema_info() = default; + +common_schema_info::common_schema_info(common_schema_info &&) noexcept = default; +common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default; + +void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) { + impl_->resolve_refs(schema, ""); +} + +// Determines if a JSON schema can resolve to a string type through any path. +// Some models emit raw string values rather than JSON-encoded strings for string parameters. +// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns +// true, allowing callers to handle the value as a raw string for simplicity. 
+bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) { + std::unordered_set visited_refs; + + std::function check = [&](const json & s) -> bool { + if (!s.is_object()) { + return false; + } + + // Handle $ref + if (s.contains("$ref")) { + const std::string & ref = s["$ref"]; + if (visited_refs.find(ref) != visited_refs.end()) { + // Circular reference, assume not a string to be safe + return false; + } + visited_refs.insert(ref); + auto it = impl_->_refs.find(ref); + if (it != impl_->_refs.end()) { + return check(it->second); + } + return false; + } + + // Check type field + if (s.contains("type")) { + const json & schema_type = s["type"]; + if (schema_type.is_string()) { + if (schema_type == "string") { + return true; + } + } else if (schema_type.is_array()) { + // Type can be an array like ["string", "null"] + for (const auto & t : schema_type) { + if (t == "string") { + return true; + } + } + } + } + + // Check oneOf/anyOf - if any alternative can be a string + if (s.contains("oneOf")) { + for (const auto & alt : s["oneOf"]) { + if (check(alt)) { + return true; + } + } + } + if (s.contains("anyOf")) { + for (const auto & alt : s["anyOf"]) { + if (check(alt)) { + return true; + } + } + } + + // Check allOf - all components must be compatible with string type + if (s.contains("allOf")) { + bool all_string = true; + for (const auto & component : s["allOf"]) { + if (!check(component)) { + all_string = false; + break; + } + } + if (all_string) { + return true; + } + } + + // Check const - if the constant value is a string + if (s.contains("const")) { + if (s["const"].is_string()) { + return true; + } + } + + // Check enum - if any enum value is a string + if (s.contains("enum")) { + for (const auto & val : s["enum"]) { + if (val.is_string()) { + return true; + } + } + } + + // String-specific keywords imply string type + if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) { + return true; + } + + // Check format - many formats imply string + if (s.contains("format")) { + const std::string & fmt = s["format"]; + if (fmt == "date" || fmt == "time" || fmt == "date-time" || + fmt == "uri" || fmt == "email" || fmt == "hostname" || + fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" || + fmt.find("uuid") == 0) { + return true; + } + } + + return false; + }; + + return check(schema); +} + std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { #ifdef LLAMA_USE_LLGUIDANCE if (!force_gbnf) { @@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { } std::string build_grammar(const std::function & cb, const common_grammar_options & options) { - SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall); + common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall); common_grammar_builder builder { /* .add_rule = */ [&](const std::string & name, const std::string & rule) { return converter._add_rule(name, rule); diff --git a/llama/llama.cpp/common/json-schema-to-grammar.h b/llama/llama.cpp/common/json-schema-to-grammar.h index c89ab7f9..240d6423 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.h +++ b/llama/llama.cpp/common/json-schema-to-grammar.h @@ -3,11 +3,31 @@ #include #include +#include #include std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, bool force_gbnf = false); +class common_schema_converter; + +// Probes a JSON schema to extract information about its structure and type 
constraints. +class common_schema_info { + std::unique_ptr impl_; + + public: + common_schema_info(); + ~common_schema_info(); + + common_schema_info(const common_schema_info &) = delete; + common_schema_info & operator=(const common_schema_info &) = delete; + common_schema_info(common_schema_info &&) noexcept; + common_schema_info & operator=(common_schema_info &&) noexcept; + + void resolve_refs(nlohmann::ordered_json & schema); + bool resolves_to_string(const nlohmann::ordered_json & schema); +}; + struct common_grammar_builder { std::function add_rule; std::function add_schema; diff --git a/llama/llama.cpp/common/log.cpp b/llama/llama.cpp/common/log.cpp index 00a03f15..b17d2b62 100644 --- a/llama/llama.cpp/common/log.cpp +++ b/llama/llama.cpp/common/log.cpp @@ -420,6 +420,11 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) { log->set_timestamps(timestamps); } +void common_log_flush(struct common_log * log) { + log->pause(); + log->resume(); +} + static int common_get_verbosity(enum ggml_log_level level) { switch (level) { case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG; diff --git a/llama/llama.cpp/common/log.h b/llama/llama.cpp/common/log.h index b24f5f00..f0f8471b 100644 --- a/llama/llama.cpp/common/log.h +++ b/llama/llama.cpp/common/log.h @@ -84,6 +84,7 @@ void common_log_set_file (struct common_log * log, const char * file); // n void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix +void common_log_flush (struct common_log * log); // flush all pending log messages // helper macros for logging // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index 7a6b7be1..6935d84e 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ b/llama/llama.cpp/common/sampling.cpp @@ -104,9 +104,10 @@ struct ring_buffer { struct common_sampler { common_params_sampling params; - struct llama_sampler * grmr; struct llama_sampler * chain; + bool grammar; + ring_buffer prev; std::vector cur; @@ -116,7 +117,6 @@ struct common_sampler { void reset() { prev.clear(); - llama_sampler_reset(grmr); llama_sampler_reset(chain); } @@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co lparams.no_perf = params.no_perf; - struct llama_sampler * grmr; + llama_sampler * chain = llama_sampler_chain_init(lparams); + + bool grammar = false; + std::vector samplers; + if (params.grammar.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE - grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); + samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str())); + grammar = true; #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE @@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co trigger_patterns_c.push_back(regex.c_str()); } - grmr = params.grammar_lazy - ? 
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", - trigger_patterns_c.data(), trigger_patterns_c.size(), - trigger_tokens.data(), trigger_tokens.size()) - : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - if (!grmr) { - return nullptr; + if (!params.grammar.empty()) { + if (params.grammar_lazy) { + samplers.push_back( + llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size())); + } else { + samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root")); + } + + grammar = true; } } - auto * result = new common_sampler { - /* .params = */ params, - /* .grmr = */ grmr, - /* .chain = */ llama_sampler_chain_init(lparams), - /* .prev = */ ring_buffer(std::max(32, params.n_prev)), - /* .cur = */ {}, - /* .cur_p = */ {}, - }; - - llama_sampler_chain_add(result->chain, - llama_sampler_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); + if (params.has_logit_bias()) { + samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); + } if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { @@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + samplers.push_back(llama_sampler_init_top_k (params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); break; case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + samplers.push_back(llama_sampler_init_temp_ext (params.temp, 
params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + samplers.push_back(llama_sampler_init_infill (vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); break; default: GGML_ASSERT(false && "unknown sampler type"); } } - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + + samplers.push_back(llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + samplers.push_back(llama_sampler_init_temp(params.temp)); + samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); + samplers.push_back(llama_sampler_init_temp(params.temp)); + samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } + for (auto * smpl : samplers) { + llama_sampler_chain_add(chain, smpl); + } + + auto * result = new common_sampler { + /* .params = */ params, + /* .chain = */ chain, + /* .grammar = */ grammar, + /* .prev = */ ring_buffer(std::max(32, params.n_prev)), + /* .cur = */ {}, + /* .cur_p = */ {}, + }; + return result; } void common_sampler_free(struct common_sampler * gsmpl) { if (gsmpl) { - llama_sampler_free(gsmpl->grmr); - llama_sampler_free(gsmpl->chain); delete gsmpl; @@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) { void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { const auto tm = gsmpl->tm(); - if (accept_grammar) { - llama_sampler_accept(gsmpl->grmr, token); - } + if (gsmpl->grammar) { + const int n_smpl = llama_sampler_chain_n(gsmpl->chain); - llama_sampler_accept(gsmpl->chain, token); + for (int i = 0; i < n_smpl; i++) { + auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + + // the grammar sampler is always the first one + if (i == 0) { + if (accept_grammar) { + llama_sampler_accept(smpl, token); + } + } else { + llama_sampler_accept(smpl, token); + } + } + } else { + llama_sampler_accept(gsmpl->chain, token); + } gsmpl->prev.push_back(token); } @@ -329,12 +352,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) { struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { return new common_sampler { - /* .params = */ gsmpl->params, - /* .grmr = */ llama_sampler_clone(gsmpl->grmr), - /* .chain = */ llama_sampler_clone(gsmpl->chain), - /* .prev = */ gsmpl->prev, - /* .cur = */ gsmpl->cur, - /* .cur_p = */ gsmpl->cur_p, + /* .params = */ gsmpl->params, + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .grammar = */ gsmpl->grammar, + /* .prev = */ gsmpl->prev, + /* .cur = 
*/ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, }; } @@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } } -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { +struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { + return gsmpl->chain; +} + +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { llama_synchronize(ctx); // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations const auto tm = gsmpl->tm(); - gsmpl->set_logits(ctx, idx); + llama_token id = LLAMA_TOKEN_NULL; - auto & grmr = gsmpl->grmr; auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits - if (grammar_first) { - llama_sampler_apply(grmr, &cur_p); - } + gsmpl->set_logits(ctx, idx); llama_sampler_apply(chain, &cur_p); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); - const llama_token id = cur_p.data[cur_p.selected].id; + id = cur_p.data[cur_p.selected].id; - if (grammar_first) { - return id; - } - - // check if it the sampled token fits the grammar - { - llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; - - llama_sampler_apply(grmr, &single_token_data_array); - - const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; - if (is_valid) { - return id; - } - } - - // resampling: - // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain - gsmpl->set_logits(ctx, idx); - - llama_sampler_apply(grmr, &cur_p); - llama_sampler_apply(chain, &cur_p); - - GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration"); - - return cur_p.data[cur_p.selected].id; + return id; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -442,7 +440,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample size_t i = 0; for (; i < draft.size(); i++) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); common_sampler_accept(gsmpl, id, true); @@ -454,7 +452,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample } if (i == draft.size()) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); common_sampler_accept(gsmpl, id, true); @@ -464,13 +462,13 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample return result; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) { std::vector idxs(draft.size() + 1); 
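With the grammar sampler folded into the single chain, callers of common/sampling no longer pass grammar_first or go through the removed sample / grammar-check / resample path. A minimal caller-side sketch (not part of this patch; the helper name and the idx value are illustrative) of one generation step under the new signatures:

// sample_one: hypothetical helper showing the simplified flow after this change.
// Grammar handling (when params.grammar is set) now lives in chain slot 0 and is
// advanced by common_sampler_accept, so the caller only samples and accepts.
#include "llama.h"
#include "sampling.h" // llama.cpp/common

static llama_token sample_one(struct common_sampler * gsmpl, struct llama_context * ctx) {
    // idx = -1: sample from the logits of the last token in the batch
    const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1);

    // accept_grammar = true also feeds the token to the grammar sampler, if any
    common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);

    return id;
}

This pair of calls replaces the old grammar_first branch and its resampling fallback, which the hunk above deletes.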
for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { @@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) { for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - result += std::string("-> ") + llama_sampler_name(smpl) + " "; + result += std::string("-> "); + result += std::string(llama_sampler_name(smpl)) + " "; } return result; diff --git a/llama/llama.cpp/common/sampling.h b/llama/llama.cpp/common/sampling.h index e198eecd..ace5d3d0 100644 --- a/llama/llama.cpp/common/sampling.h +++ b/llama/llama.cpp/common/sampling.h @@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); +struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); + // extended sampling implementation: // // - set logits @@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -// if grammar_first is true, the grammar is applied before the samplers (slower) -// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar -// -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); // generalized version of common_sampler_sample // @@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); @@ -107,3 +106,9 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); + +struct common_sampler_deleter { + void operator()(common_sampler * s) { common_sampler_free(s); } +}; + +typedef std::unique_ptr common_sampler_ptr; diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index b52eaacf..f8629300 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -313,6 +313,7 @@ extern "C" { bool check_tensors; // validate model tensor data bool use_extra_bufts; // use 
extra buffer types (used for weight repacking) bool no_host; // bypass host buffer allowing extra buffers to be used + bool no_alloc; // only load metadata and simulate memory allocations }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -466,10 +467,24 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); + // fits mparams and cparams to free device memory (assumes system memory is unlimited) + // returns true if the parameters could be successfully modified to fit device memory + // this function is NOT thread safe because it modifies the global llama logger state + LLAMA_API bool llama_params_fit( + const char * path_model, + struct llama_model_params * mparams, + struct llama_context_params * cparams, + float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements + struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements + size_t margin, // margin of memory to leave per device in bytes + uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use + enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log + LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); LLAMA_API size_t llama_max_parallel_sequences(void); + LLAMA_API size_t llama_max_tensor_buft_overrides(void); LLAMA_API bool llama_supports_mmap (void); LLAMA_API bool llama_supports_mlock (void); @@ -1354,7 +1369,9 @@ extern "C" { // Set callback for all future logging events. // If this is not called, or NULL is supplied, everything is output on stderr. - LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); + // The logger state is global so these functions are NOT thread safe. 
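A minimal sketch (not part of this patch) of driving the new llama_params_fit declaration from a caller; the margin and n_ctx_min values are placeholders, and the scratch buffers are sized exactly as the parameter comments above require:

// fit_or_fail: hypothetical caller of llama_params_fit, based only on the declaration above.
#include <cstdint>
#include <vector>
#include "llama.h"

static bool fit_or_fail(const char * path_model,
                        llama_model_params   & mparams,
                        llama_context_params & cparams) {
    // writable scratch buffers, sized per the parameter comments
    std::vector<float> tensor_split(llama_max_devices(), 0.0f);
    std::vector<llama_model_tensor_buft_override> overrides(llama_max_tensor_buft_overrides());

    const size_t   margin    = 512ull * 1024 * 1024; // placeholder: leave ~512 MiB free per device
    const uint32_t n_ctx_min = 4096;                 // placeholder: smallest acceptable context

    // NOTE: not thread safe - llama_params_fit modifies the global llama logger state
    return llama_params_fit(path_model, &mparams, &cparams,
                            tensor_split.data(), overrides.data(),
                            margin, n_ctx_min, GGML_LOG_LEVEL_INFO);
}

On success the adjusted mparams and cparams would then feed the normal model/context creation path.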
+ LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data); + LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); // // Performance utils diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index a5fe4f66..2ce8ffec 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -3,6 +3,7 @@ #include "llama-impl.h" #include +#include static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize @@ -75,6 +76,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, + { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -303,2271 +305,1901 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, }; -static const std::map> LLM_TENSOR_NAMES = { - { - LLM_ARCH_CLIP, - {}, - }, - { - LLM_ARCH_LLAMA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_ARCEE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_AFMOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - 
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - }, - }, - { - LLM_ARCH_LLAMA4, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - }, - }, - { - LLM_ARCH_DECI, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_BAICHUAN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_FALCON, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { 
LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_GROK, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - }, - }, - { - LLM_ARCH_GPT2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_POS_EMBD, "position_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_GPTJ, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, - }, - { - LLM_ARCH_GPTNEOX, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_MPT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output"}, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, - { LLM_TENSOR_POS_EMBD, "position_embd" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, - }, - }, - { - LLM_ARCH_STARCODER, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_POS_EMBD, "position_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - 
{ - LLM_ARCH_REFACT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_BERT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_TOKEN_TYPES, "token_types" }, - { LLM_TENSOR_POS_EMBD, "position_embd" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_CLS, "cls" }, - { LLM_TENSOR_CLS_OUT, "cls.output" }, - }, - }, - { - LLM_ARCH_NOMIC_BERT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_TOKEN_TYPES, "token_types" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_NOMIC_BERT_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_TOKEN_TYPES, "token_types" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_NEO_BERT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, - { LLM_TENSOR_CLS, "cls" }, - { LLM_TENSOR_CLS_OUT, "cls.output" }, - }, - }, - { - LLM_ARCH_JINA_BERT_V2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_TOKEN_TYPES, "token_types" }, - { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - 
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_CLS, "cls" }, - }, - }, - { - LLM_ARCH_JINA_BERT_V3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_TOKEN_TYPES, "token_types" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - }, - }, - { - LLM_ARCH_BLOOM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_STABLELM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - }, - }, - { - LLM_ARCH_QWEN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_QWEN2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_QWEN2VL, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, 
"blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_QWEN2MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - }, - }, - { - LLM_ARCH_QWEN3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_CLS_OUT, "cls.output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_QWEN3MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_QWEN3NEXT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, 
- { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - }, - }, - { - LLM_ARCH_QWEN3VL, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_QWEN3VLMOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_PHI2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_PHI3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, - { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_PHIMOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, - { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, 
- { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_PLAMO, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_PLAMO2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, - { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, - { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - }, - }, - { - LLM_ARCH_CODESHELL, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_ORION, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_INTERNLM2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { 
LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_MINICPM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, - { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - }, - }, - { - LLM_ARCH_MINICPM3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, - { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, - { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, - { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, - { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, - { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_GEMMA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_GEMMA2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - }, - }, - { - LLM_ARCH_GEMMA3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { 
LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - }, - }, - { - LLM_ARCH_GEMMA3N, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" }, - { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" }, - { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" }, - { LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" }, - { LLM_TENSOR_ALTUP_PROJ, "altup_proj" }, - { LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" }, - { LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" }, - { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" }, - { LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" }, - { LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" }, - { LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" }, - { LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" }, - { LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" }, - { LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" }, - { LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" }, - { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, - }, - }, - { - LLM_ARCH_GEMMA_EMBEDDING, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, - { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - }, - }, - { - LLM_ARCH_STARCODER2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, 
- { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_MAMBA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - }, - }, - { - LLM_ARCH_MAMBA2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - }, - }, - { - LLM_ARCH_JAMBA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, - { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_FALCON_H1, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_XVERSE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" 
}, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_COMMAND_R, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - }, - }, - { - LLM_ARCH_COHERE2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_DBRX, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_OLMO, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_OLMO2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_OLMOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, 
"blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_OPENELM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_ARCTIC, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_DEEPSEEK, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - }, - }, - { - LLM_ARCH_DEEPSEEK2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, - { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, - { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, - { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, - { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, - { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, - { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, 
- { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - }, - }, - { - LLM_ARCH_PLM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, - { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, - { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_CHATGLM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_GLM4, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - }, - }, - { - LLM_ARCH_GLM4_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { 
LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) - { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, - { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, - { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, - { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, - }, - }, - { - LLM_ARCH_BITNET, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, - }, - }, - { - LLM_ARCH_T5, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, - { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, - { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, - { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, - { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, - { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, - { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, - { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, - { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, - { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, - { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, - { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, - { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, - { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, - { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, - { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, - { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, - { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, - { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, - { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, - { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, - { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, - { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, - { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, - { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, - { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, - { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, - { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_T5ENCODER, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, - { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, - { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, - { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, - { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, - { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, - { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, - { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, - { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, - { LLM_TENSOR_ENC_FFN_DOWN, 
"enc.blk.%d.ffn_down" }, - { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_JAIS, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_NEMOTRON, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_NEMOTRON_H, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - // mamba(2) ssm layers - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - // attention layers - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - // dense FFN - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_EXAONE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_EXAONE4, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, - } - }, - { - LLM_ARCH_RWKV6, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { 
LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, - { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, - { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, - { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, - { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, - { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, - { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, - { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, - { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, - { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, - { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, - { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, - { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, - { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, - { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, - { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, - { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, - { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, - { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, - { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, - { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, - { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, - { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, - { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, - { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, - }, - }, - { - LLM_ARCH_RWKV6QWEN2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, - { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, - { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, - { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, - { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, - { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, - { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, - { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, - { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, - { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, - { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, - { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, - { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_RWKV7, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, - { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, - { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, - { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, - { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, - { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, - { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, - { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, - { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, - { LLM_TENSOR_TIME_MIX_V2, 
"blk.%d.time_mix_v2" }, - { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, - { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, - { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, - { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, - { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, - { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, - { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, - { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, - { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, - { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, - { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, - { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, - { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, - { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, - }, - }, - { - LLM_ARCH_ARWKV7, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, - { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, - { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, - { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, - { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, - { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, - { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, - { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, - { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, - { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, - { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, - { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, - { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, - { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, - { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, - { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, - { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, - { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, - { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, - { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_GRANITE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_GRANITE_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { 
LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - }, - }, - { - LLM_ARCH_GRANITE_HYBRID, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - // mamba(2) ssm layers - { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, - { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, - { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, - { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, - { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, - { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, - { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, - // attention layers - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - // dense FFN - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - // moe FFN - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - // shared expert - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - }, - }, - { - LLM_ARCH_CHAMELEON, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - }, - }, - { - LLM_ARCH_SOLAR, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, - }, - }, - { - LLM_ARCH_WAVTOKENIZER_DEC, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_CONV1D, "conv1d" }, - { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" }, - { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" }, - { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" }, - { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" }, - { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" }, - { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" }, - { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" }, 
- { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" }, - { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" }, - { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" }, - { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" }, - { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" }, - { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" }, - { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, - }, - }, - { - LLM_ARCH_BAILINGMOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - }, - }, - { - LLM_ARCH_BAILINGMOE2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, - { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, - { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, - { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, - { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, - { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, - }, - }, - { - LLM_ARCH_DOTS1, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" 
}, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - } - }, - { - LLM_ARCH_ERNIE4_5, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_ERNIE4_5_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - }, - }, - { - LLM_ARCH_HUNYUAN_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, - { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, - { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_HUNYUAN_DENSE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { 
LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - - }, - }, - { - LLM_ARCH_SMOLLM3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_OPENAI_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_LFM2, - { - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, - { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, - { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name - { LLM_TENSOR_OUTPUT, "output" }, - } - }, - { - LLM_ARCH_LFM2MOE, - { - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, - { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, - { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - } - }, - { - LLM_ARCH_SMALLTHINKER, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { 
LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } - }, - }, - { - LLM_ARCH_APERTUS, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_DREAM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_LLADA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_LLADA_MOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_SEED_OSS, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_POST_NORM, 
"blk.%d.post_attention_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_GROVEMOE, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" }, - { LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" }, - { LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" }, - }, - }, - { - LLM_ARCH_MINIMAX_M2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, - }, - }, - { - LLM_ARCH_PANGU_EMBED, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_COGVLM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" }, - { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" }, - { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, - { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, - { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, - }, - }, - { - LLM_ARCH_RND1, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, - { LLM_TENSOR_ATTN_K, 
"blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_MISTRAL3, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - }, - }, - { - LLM_ARCH_UNKNOWN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, - }, +static const std::map LLM_TENSOR_NAMES = { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT_NORM_LFM2, "token_embd_norm" }, // fix for wrong tensor name + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, 
"token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" }, + { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" }, + { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" }, + { LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" }, + { LLM_TENSOR_ALTUP_PROJ, "altup_proj" }, + { LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" }, + { LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" }, + { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" }, + { LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" }, + { LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" }, + { LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" }, + { LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" }, + { LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" }, + { LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" }, + { LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" }, + { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { 
LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" }, + { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" }, + { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" }, + { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" }, + { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" }, + { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" }, + { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" }, + { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" }, + { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" }, + { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" }, + { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" }, + { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" }, + { LLM_TENSOR_CONV1D, "conv1d" }, + { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" }, + { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" }, + { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" }, + { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" }, + { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" }, + { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" }, + { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" }, + { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" }, + { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" }, + { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" }, + { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" }, 
+ { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" },
+ { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" },
+ { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" },
+ { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
+ { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
+ { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
+ { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
+ { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
+ { LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" },
+ { LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" },
+ { LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" },
+ { LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" },
+ { LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" },
+ { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
+ { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
+ { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
 };
+static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
+ switch (arch) {
+ case LLM_ARCH_CLIP:
+ return {};
+ case LLM_ARCH_LLAMA:
+ case LLM_ARCH_DECI:
+ case LLM_ARCH_MISTRAL3:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ROPE_FREQS,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_ATTN_ROT_EMBD,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_GATE,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_GATE_EXP,
+ LLM_TENSOR_FFN_DOWN_EXP,
+ LLM_TENSOR_FFN_UP_EXP,
+ LLM_TENSOR_FFN_GATE_EXPS,
+ LLM_TENSOR_FFN_DOWN_EXPS,
+ LLM_TENSOR_FFN_UP_EXPS,
+ };
+ case LLM_ARCH_ARCEE:
+ case LLM_ARCH_STARCODER2:
+ case LLM_ARCH_NEMOTRON:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ROPE_FREQS,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_ATTN_ROT_EMBD,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ };
+ case LLM_ARCH_AFMOE:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_POST_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_ATTN_Q_NORM,
+ LLM_TENSOR_ATTN_K_NORM,
+ LLM_TENSOR_ATTN_GATE,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_POST_NORM,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_GATE,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_GATE_EXPS,
+ LLM_TENSOR_FFN_DOWN_EXPS,
+ LLM_TENSOR_FFN_UP_EXPS,
+ LLM_TENSOR_FFN_GATE_SHEXP,
+ LLM_TENSOR_FFN_UP_SHEXP,
+ LLM_TENSOR_FFN_DOWN_SHEXP,
+ LLM_TENSOR_FFN_EXP_PROBS_B,
+ };
+ case LLM_ARCH_LLAMA4:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ROPE_FREQS,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_Q,
+ LLM_TENSOR_ATTN_K,
+ LLM_TENSOR_ATTN_V,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_ATTN_ROT_EMBD,
+ LLM_TENSOR_FFN_GATE_INP,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_GATE,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ LLM_TENSOR_FFN_GATE_EXP,
+ LLM_TENSOR_FFN_DOWN_EXP,
+ LLM_TENSOR_FFN_UP_EXP,
+ LLM_TENSOR_FFN_GATE_EXPS,
+ LLM_TENSOR_FFN_DOWN_EXPS,
+ LLM_TENSOR_FFN_UP_EXPS,
+ LLM_TENSOR_FFN_GATE_SHEXP,
+ LLM_TENSOR_FFN_DOWN_SHEXP,
+ LLM_TENSOR_FFN_UP_SHEXP,
+ };
+ case LLM_ARCH_BAICHUAN:
+ case LLM_ARCH_ORION:
+ case LLM_ARCH_XVERSE:
+ case LLM_ARCH_EXAONE:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ROPE_FREQS,
+ LLM_TENSOR_ATTN_NORM,
+
LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_FALCON: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_GROK: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_DOWN_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_ATTN_OUT_NORM, + }; + case LLM_ARCH_GPT2: + case LLM_ARCH_STARCODER: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; + case LLM_ARCH_GPTNEOX: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_MPT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_ACT, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + }; + case LLM_ARCH_REFACT: + case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2VL: + case LLM_ARCH_INTERNLM2: + case LLM_ARCH_GRANITE: + case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_SMOLLM3: + case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: + case LLM_ARCH_PANGU_EMBED: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_BERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + }; + case LLM_ARCH_NOMIC_BERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_NOMIC_BERT_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + 
LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_NEO_BERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + }; + case LLM_ARCH_JINA_BERT_V2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_CLS, + }; + case LLM_ARCH_JINA_BERT_V3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_LAYER_OUT_NORM, + }; + case LLM_ARCH_BLOOM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; + case LLM_ARCH_STABLELM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + }; + case LLM_ARCH_QWEN: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_QWEN2MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; + case LLM_ARCH_QWEN3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_CLS_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_OLMOE: + case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_QWEN3NEXT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + 
LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA_ALPHA, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN3VL: + case LLM_ARCH_CHAMELEON: + case LLM_ARCH_HUNYUAN_DENSE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_PHI2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_PHI3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_PHIMOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_PLAMO: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_PLAMO2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_DT_NORM, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_POST_NORM, + }; + case LLM_ARCH_CODESHELL: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_MINICPM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + 
LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_DOWN_EXP, + LLM_TENSOR_FFN_UP_EXP, + }; + case LLM_ARCH_MINICPM3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; + case LLM_ARCH_GEMMA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_GEMMA2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_POST_NORM, + }; + case LLM_ARCH_GEMMA3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_POST_NORM, + }; + case LLM_ARCH_GEMMA3N: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_PER_LAYER_TOKEN_EMBD, + LLM_TENSOR_PER_LAYER_MODEL_PROJ, + LLM_TENSOR_PER_LAYER_PROJ_NORM, + LLM_TENSOR_ALTUP_UNEMBD_PROJ, + LLM_TENSOR_ALTUP_PROJ, + LLM_TENSOR_PER_LAYER_INP_GATE, + LLM_TENSOR_PER_LAYER_PROJ, + LLM_TENSOR_PER_LAYER_POST_NORM, + LLM_TENSOR_ALTUP_CORRECT_COEF, + LLM_TENSOR_ALTUP_CORRECT_SCALE, + LLM_TENSOR_ALTUP_PREDICT_COEF, + LLM_TENSOR_ALTUP_ROUTER, + LLM_TENSOR_ALTUP_ROUTER_NORM, + LLM_TENSOR_LAUREL_L, + LLM_TENSOR_LAUREL_R, + LLM_TENSOR_LAUREL_POST_NORM, + }; + case LLM_ARCH_GEMMA_EMBEDDING: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_POST_NORM, + }; + case LLM_ARCH_MAMBA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + 
LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_MAMBA2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_JAMBA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_FALCON_H1: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_COMMAND_R: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + }; + case LLM_ARCH_COHERE2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_DBRX: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_OLMO: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_OLMO2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_OPENELM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_ARCTIC: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE_INP, + 
LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM_EXPS, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_DEEPSEEK: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; + case LLM_ARCH_DEEPSEEK2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; + case LLM_ARCH_PLM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_CHATGLM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; + case LLM_ARCH_GLM4: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_POST_NORM, + }; + case LLM_ARCH_GLM4_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; + case LLM_ARCH_BITNET: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + 
LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_SUB_NORM, + }; + case LLM_ARCH_T5: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + }; + case LLM_ARCH_T5ENCODER: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + }; + case LLM_ARCH_JAIS: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + }; + case LLM_ARCH_NEMOTRON_H: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_NEMOTRON_H_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + // mamba(2) ssm layers + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + // attention layers + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + // dense FFN + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + // MoE FFN (for MoE layers) + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + // MoE shared expert layer + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; + case LLM_ARCH_EXAONE4: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_POST_NORM, + }; + case LLM_ARCH_RWKV6: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_W, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + 
LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_LERP_FUSED, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_LERP_R, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_VALUE, + LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, + }; + case LLM_ARCH_RWKV6QWEN2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_FUSED, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_RWKV7: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_TIME_MIX_W0, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_A0, + LLM_TENSOR_TIME_MIX_A1, + LLM_TENSOR_TIME_MIX_A2, + LLM_TENSOR_TIME_MIX_V0, + LLM_TENSOR_TIME_MIX_V1, + LLM_TENSOR_TIME_MIX_V2, + LLM_TENSOR_TIME_MIX_G1, + LLM_TENSOR_TIME_MIX_G2, + LLM_TENSOR_TIME_MIX_K_K, + LLM_TENSOR_TIME_MIX_K_A, + LLM_TENSOR_TIME_MIX_R_K, + LLM_TENSOR_TIME_MIX_LERP_FUSED, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_VALUE, + }; + case LLM_ARCH_ARWKV7: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_TIME_MIX_W0, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_A0, + LLM_TENSOR_TIME_MIX_A1, + LLM_TENSOR_TIME_MIX_A2, + LLM_TENSOR_TIME_MIX_V0, + LLM_TENSOR_TIME_MIX_V1, + LLM_TENSOR_TIME_MIX_V2, + LLM_TENSOR_TIME_MIX_G1, + LLM_TENSOR_TIME_MIX_G2, + LLM_TENSOR_TIME_MIX_K_K, + LLM_TENSOR_TIME_MIX_K_A, + LLM_TENSOR_TIME_MIX_R_K, + LLM_TENSOR_TIME_MIX_LERP_FUSED, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_GRANITE_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; + case LLM_ARCH_GRANITE_HYBRID: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + 
LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; + case LLM_ARCH_WAVTOKENIZER_DEC: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_CONV1D, + LLM_TENSOR_CONVNEXT_DW, + LLM_TENSOR_CONVNEXT_NORM, + LLM_TENSOR_CONVNEXT_PW1, + LLM_TENSOR_CONVNEXT_PW2, + LLM_TENSOR_CONVNEXT_GAMMA, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_POS_NET_CONV1, + LLM_TENSOR_POS_NET_CONV2, + LLM_TENSOR_POS_NET_NORM, + LLM_TENSOR_POS_NET_NORM1, + LLM_TENSOR_POS_NET_NORM2, + LLM_TENSOR_POS_NET_ATTN_NORM, + LLM_TENSOR_POS_NET_ATTN_Q, + LLM_TENSOR_POS_NET_ATTN_K, + LLM_TENSOR_POS_NET_ATTN_V, + LLM_TENSOR_POS_NET_ATTN_OUT, + }; + case LLM_ARCH_BAILINGMOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + }; + case LLM_ARCH_BAILINGMOE2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + }; + case LLM_ARCH_DOTS1: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; + case LLM_ARCH_ERNIE4_5_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; + case LLM_ARCH_HUNYUAN_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + 
LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_OPENAI_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_LFM2: + return { + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SHORTCONV_CONV, + LLM_TENSOR_SHORTCONV_INPROJ, + LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM_LFM2, + LLM_TENSOR_OUTPUT, + }; + case LLM_ARCH_LFM2MOE: + return { + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SHORTCONV_CONV, + LLM_TENSOR_SHORTCONV_INPROJ, + LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; + case LLM_ARCH_SMALLTHINKER: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + }; + case LLM_ARCH_APERTUS: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_SEED_OSS: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_GROVEMOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_CHEXPS, + LLM_TENSOR_FFN_DOWN_CHEXPS, + LLM_TENSOR_FFN_UP_CHEXPS, + }; + case LLM_ARCH_MINIMAX_M2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + 
LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; + case LLM_ARCH_COGVLM: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_VISEXP_ATTN_QKV, + LLM_TENSOR_VISEXP_ATTN_OUT, + LLM_TENSOR_VISEXP_FFN_GATE, + LLM_TENSOR_VISEXP_FFN_DOWN, + LLM_TENSOR_VISEXP_FFN_UP, + }; + case LLM_ARCH_GPTJ: + case LLM_ARCH_UNKNOWN: + return { + LLM_TENSOR_TOKEN_EMBD, + }; + case LLM_ARCH_SOLAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_BSKCN_TV, + }; + default: + GGML_ABORT("unknown architecture for tensor mapping"); + } +} + // declare information about the model weight tensors: // - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight // - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator @@ -2589,6 +2221,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -2778,13 +2411,20 @@ std::string LLM_KV::operator()(llm_kv kv) const { return name; } +LLM_TN_IMPL::LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid) + : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid), + model_tensors(llm_get_tensor_names(arch)) {} + std::string LLM_TN_IMPL::str() const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { - return "__missing__"; + if (LLM_TENSOR_NAMES.find(tensor) == LLM_TENSOR_NAMES.end()) { + GGML_ABORT("unknown tensor name for tensor id %d", static_cast(tensor)); } - std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid); + if (model_tensors.find(tensor) == model_tensors.end()) { + return LLM_TENSOR_NAMES.at(tensor); + } + std::string name = ::format(LLM_TENSOR_NAMES.at(tensor), bid, xid); if (suffix != nullptr) { name += "."; name += suffix; @@ -2838,6 +2478,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: return true; default: diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index ec9e3a6d..14d461c7 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -3,6 +3,7 @@ #include "ggml.h" // ggml_op #include +#include // // gguf constants (sync with gguf.py) @@ -79,6 
+80,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, + LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, @@ -317,6 +319,7 @@ enum llm_tensor { LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT_NORM_LFM2, // fix for wrong tensor name LLM_TENSOR_ROPE_FREQS, LLM_TENSOR_ROPE_FACTORS_LONG, LLM_TENSOR_ROPE_FACTORS_SHORT, @@ -528,6 +531,10 @@ struct LLM_TN_IMPL { const int bid; const int xid; + const std::set model_tensors; + + LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid); + std::string str() const; operator std::string() const { @@ -549,11 +556,11 @@ struct LLM_TN { llm_arch arch; LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const { - return { arch, tensor, suffix, bid, xid }; + return LLM_TN_IMPL(arch, tensor, suffix, bid, xid); } LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const { - return { arch, tensor, nullptr, bid, xid }; + return LLM_TN_IMPL(arch, tensor, nullptr, bid, xid); } }; diff --git a/llama/llama.cpp/src/llama-batch.cpp b/llama/llama.cpp/src/llama-batch.cpp index 86a1a4ba..386fab04 100644 --- a/llama/llama.cpp/src/llama-batch.cpp +++ b/llama/llama.cpp/src/llama-batch.cpp @@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->seq_idx .resize(LLAMA_MAX_SEQ, -1); udata->output .resize(n_tokens); + udata->seq_id_data.reserve(n_tokens); + seq_set_t seq_set_unq; for (size_t i = 0; i < idxs.size(); ++i) { @@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; - udata->seq_id[i] = batch.seq_id[idxs[i]]; udata->output[i] = batch.logits[idxs[i]]; for (int s = 0; s < udata->n_seq_id[i]; ++s) { - seq_set_unq.set(udata->seq_id[i][s]); + const llama_seq_id seq_id = batch.seq_id[idxs[i]][s]; + + udata->seq_id_data.push_back(seq_id); + seq_set_unq.set(seq_id); } if (udata->output[i]) { @@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } } + llama_seq_id * seq_id_ptr = udata->seq_id_data.data(); + for (size_t i = 0; i < idxs.size(); ++i) { + udata->seq_id[i] = seq_id_ptr; + seq_id_ptr += udata->n_seq_id[i]; + } + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { udata->seq_idx[s] = udata->seq_id_unq.size(); diff --git a/llama/llama.cpp/src/llama-batch.h b/llama/llama.cpp/src/llama-batch.h index 209cf369..8e6fac0e 100644 --- a/llama/llama.cpp/src/llama-batch.h +++ b/llama/llama.cpp/src/llama-batch.h @@ -56,13 +56,15 @@ struct llama_ubatch { std::vector embd; std::vector pos; std::vector n_seq_id; - std::vector seq_id; + std::vector seq_id; // these point into the seq_id_data below std::vector seq_id_unq; std::vector seq_idx; std::vector output; + + std::vector seq_id_data; }; - // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data + // the llama_ubatch pointers above point to this data if set. 
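The seq_id change above stops aliasing the caller's batch.seq_id arrays: the ids are copied into the owned seq_id_data buffer first, and the per-token pointers are assigned in a second pass once the buffer can no longer reallocate. A minimal standalone sketch of that flatten-then-fix-up pattern, using simplified stand-in types (seq_id_t and ubatch_sketch are illustrative, not the actual llama.cpp structs):

#include <cstdint>
#include <vector>

using seq_id_t = int32_t; // stand-in for llama_seq_id

struct ubatch_sketch {
    std::vector<int32_t>    n_seq_id;    // number of sequence ids per token
    std::vector<seq_id_t *> seq_id;      // per-token pointer into seq_id_data
    std::vector<seq_id_t>   seq_id_data; // owned, flat storage for all ids
};

// src[i] holds the sequence ids of token i
static ubatch_sketch flatten_seq_ids(const std::vector<std::vector<seq_id_t>> & src) {
    ubatch_sketch u;
    u.seq_id.resize(src.size(), nullptr);

    size_t total = 0;
    for (const auto & s : src) {
        total += s.size();
    }
    u.seq_id_data.reserve(total); // reserve up front so later appends never reallocate

    for (const auto & s : src) {
        u.n_seq_id.push_back((int32_t) s.size());
        u.seq_id_data.insert(u.seq_id_data.end(), s.begin(), s.end());
    }

    // second pass: the buffer is stable now, so the per-token pointers can be recorded
    seq_id_t * p = u.seq_id_data.data();
    for (size_t i = 0; i < src.size(); ++i) {
        u.seq_id[i] = p;
        p += u.n_seq_id[i];
    }

    return u;
}

The second pass mirrors the pointer fix-up loop added to ubatch_add: recording the pointers only after all ids have been copied is what makes them safe to keep for the lifetime of the ubatch.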
otherwise - point to external non-owning data std::shared_ptr data; }; diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 87f407f9..9e699827 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -9,6 +9,7 @@ #include "llama-model.h" #include +#include #include #include #include @@ -72,6 +73,43 @@ llama_context::llama_context( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; } + if (cparams.yarn_ext_factor != 0) { + static auto get_mscale = [](float scale, float mscale) { + return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); + }; + + const float factor = 1.0f / cparams.rope_freq_scale; + + // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348 + if (hparams.rope_yarn_log_mul != 0.0f) { + // note: here we assume `mscale == 1.0f` + // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f + float mscale = 1.0f; + const float mscale_all_dims = hparams.rope_yarn_log_mul; + + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // special-case DEEPSEEK v2: + // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43 + if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) { + mscale = mscale_all_dims; + } + + cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); + + LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n", + __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims); + } else { + cparams.yarn_attn_factor = get_mscale(factor, 1.0f); + } + + // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor: + // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544 + // + // ref: https://github.com/ggml-org/llama.cpp/discussions/7416 + // https://github.com/ggml-org/llama.cpp/pull/17945 + cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor)); + } + cparams.yarn_attn_factor *= hparams.rope_attn_factor; if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { @@ -93,14 +131,6 @@ llama_context::llama_context( // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask - // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) - // ref: https://github.com/ggerganov/llama.cpp/pull/5021 - // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory - if (cparams.n_batch < GGML_KQ_MASK_PAD) { - LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); - cparams.n_batch = GGML_KQ_MASK_PAD; - } cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? 
params.n_batch : params.n_ubatch); cparams.op_offload = params.op_offload; @@ -228,6 +258,7 @@ llama_context::llama_context( backend_buft.clear(); backend_ptrs.clear(); + backend_buf_exp_size.clear(); for (auto & backend : backends) { auto * buft = ggml_backend_get_default_buffer_type(backend.get()); @@ -244,6 +275,7 @@ llama_context::llama_context( backend_buft.push_back(buft); backend_ptrs.push_back(backend.get()); + backend_buf_exp_size.push_back(0); } LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); @@ -359,7 +391,8 @@ llama_context::llama_context( // reserve pp (prompt processing) graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); if (!gf) { if (pipeline_parallel) { LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); @@ -377,7 +410,7 @@ llama_context::llama_context( // reserve with tg (token generation) graph to get the number of splits and nodes { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get()); + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -392,7 +425,7 @@ llama_context::llama_context( // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -401,11 +434,13 @@ llama_context::llama_context( for (size_t i = 0; i < backend_ptrs.size(); ++i) { ggml_backend_t backend = backend_ptrs[i]; ggml_backend_buffer_type_t buft = backend_buft[i]; - size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend); - if (size > 1) { + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + if (backend_buf_exp_size[i] > 1) { LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, ggml_backend_buft_name(buft), - size / 1024.0 / 1024.0); + backend_buf_exp_size[i] / 1024.0 / 1024.0); } } @@ -424,6 +459,23 @@ llama_context::llama_context( } llama_context::~llama_context() { + // FIXME this currently results in a use-after-free bug if the model is freed before the context + // if (!model.hparams.no_alloc) { + // for (size_t i = 0; i < backend_ptrs.size(); ++i) { + // ggml_backend_t backend = backend_ptrs[i]; + // ggml_backend_buffer_type_t buft = backend_buft[i]; + + // const size_t size_exp = backend_buf_exp_size[i]; + // const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); + // if (size_exp == size_act) { + // LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", + // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + // } else { + // LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", + // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + // } + // } + // } ggml_opt_free(opt_ctx); } @@ -1325,6 +1377,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { // This doesn't happen often, 
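The yarn_attn_factor computation above is easier to follow with concrete numbers. The sketch below recomputes it outside llama.cpp; the values chosen for rope_freq_scale and for mscale_all_dims (the stand-in for hparams.rope_yarn_log_mul) are made up for illustration, and mscale is fixed to 1.0f exactly as the TODO in the patch assumes.

#include <cmath>
#include <cstdio>

// same formula as the get_mscale lambda in the patch
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * std::log(scale) + 1.0f);
}

int main() {
    const float rope_freq_scale = 0.25f; // assumed context-extension scale, i.e. factor = 4x
    const float mscale          = 1.0f;  // the patch currently assumes mscale == 1
    const float mscale_all_dims = 0.5f;  // illustrative value for rope_yarn_log_mul

    const float factor = 1.0f / rope_freq_scale;

    float attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

    // cancel the (1 + 0.1*ln(1/freq_scale)) scaling that ggml applies internally
    // when yarn_ext_factor != 0
    attn_factor *= 1.0f / (1.0f + 0.1f * std::log(factor));

    std::printf("yarn_attn_factor = %.4f\n", attn_factor);
    return 0;
}

With these illustrative inputs the program should print roughly 0.935; the real code then multiplies the result by hparams.rope_attn_factor.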
but may be annoying in some cases (like the HellaSwag benchmark) LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif + synchronize(); buf_output = nullptr; logits = nullptr; embd = nullptr; @@ -1396,7 +1449,8 @@ llm_graph_result * llama_context::get_gf_res_reserve() const { return static_cast(gf_res_reserve.get()); } -ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) { +ggml_cgraph * llama_context::graph_reserve( + uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) { LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); GGML_ASSERT(n_outputs >= 1); @@ -1433,8 +1487,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u // initialize scheduler with the specified graph if (split_only) { - ggml_backend_sched_split_graph(sched.get(), gf); + if (sizes) { + ggml_backend_sched_reserve_size(sched.get(), gf, sizes); + } else { + ggml_backend_sched_split_graph(sched.get(), gf); + } } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { + GGML_ASSERT(!sizes); LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); return nullptr; } @@ -2056,15 +2115,26 @@ void llama_context::perf_reset() { std::map llama_context::memory_breakdown() const { std::map ret; - for (const auto & buft_size : model.memory_breakdown()) { - ret[buft_size.first].model += buft_size.second; + for (const auto & [buft, size] : model.memory_breakdown()) { + ret[buft].model += size; } - for (const auto & buft_size : memory->memory_breakdown()) { - ret[buft_size.first].context += buft_size.second; + if (memory) { + for (const auto & [buft, size] : memory->memory_breakdown()) { + ret[buft].context += size; + } } - for (const auto & backend_ptr : backends) { - ggml_backend_t backend = backend_ptr.get(); - ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (model.hparams.no_alloc) { + for (size_t i = 0; i < backends.size(); ++i) { + ggml_backend_t backend = backends[i].get(); + ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); + ret[buft].compute += backend_buf_exp_size[i]; + } + } else { + for (const auto & backend_ptr : backends) { + ggml_backend_t backend = backend_ptr.get(); + ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); + ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); + } } return ret; } diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index cd26eafe..c3110133 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -26,6 +26,10 @@ struct llama_memory_breakdown_data { size_t model = 0; // memory allocated for the model size_t context = 0; // memory allocated for the context size_t compute = 0; // memory allocated for temporary compute buffers + + size_t total() const { + return model + context + compute; + } }; struct llama_context { @@ -206,7 +210,8 @@ public: ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size - ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t 
n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false); + ggml_cgraph * graph_reserve( + uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr); private: llm_graph_params graph_params( @@ -281,9 +286,10 @@ private: std::vector> set_n_threads_fns; - // buffer types used for the compute buffer of each backend + // pointers and buffer types used for the compute buffer of each backend std::vector backend_ptrs; std::vector backend_buft; + std::vector backend_buf_exp_size; // expected buffer sizes llm_graph_result_ptr gf_res_prev; llm_graph_result_ptr gf_res_reserve; diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 43620df7..1d0d7197 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -78,7 +78,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { for (int i = 0; i < n_tokens; ++i) { const float pos = ubatch->pos[i]; attn_scale_data[i] = std::log( - std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0 + std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0 ) * f_attn_temp_scale + 1.0; } @@ -254,6 +254,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= s_copy->ne[0] == mctx->get_n_rs(); + + res &= s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs; + + res &= head == mctx->get_head(); + res &= rs_z == mctx->get_rs_z(); + + return res; +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -385,7 +403,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; return res; } @@ -416,10 +434,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; return res; } @@ -452,7 +470,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int i = n_tokens; i < n_tokens; ++i) { for (int j = 0; j < n_enc; ++j) { data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; } @@ -461,8 +479,46 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - inp_attn->set_input(ubatch); - inp_rs->set_input(ubatch); + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + 
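The can_reuse additions in this file all follow the same pattern: a previously built graph is reused only when every dimension and view offset that was baked into its input tensors still matches the new ubatch and memory state. A condensed standalone sketch of that check, with illustrative field names rather than the real llm_graph_input types:

#include <cstdint>

struct reuse_params {
    int64_t  n_tokens; // second dim of the KQ mask
    int64_t  n_kv;     // first dim of the KQ mask
    int64_t  n_rs;     // length of s_copy
    int64_t  n_seqs;   // length of the s_copy_main view
    uint32_t head;     // first recurrent cell referenced by the views
    int32_t  rs_z;     // index of the zeroed recurrent state
};

struct cached_input_sketch {
    reuse_params built; // values captured when the graph was built

    bool can_reuse(const reuse_params & now) const {
        bool res = true;
        res &= built.n_tokens == now.n_tokens;
        res &= built.n_kv     == now.n_kv;
        res &= built.n_rs     == now.n_rs;
        res &= built.n_seqs   == now.n_seqs;
        res &= built.head     == now.head;
        res &= built.rs_z     == now.rs_z;
        return res;
    }
};

head and rs_z are compared exactly because they feed view offsets; a shape match alone is not sufficient for the recurrent-state views to stay valid.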
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; } // @@ -1089,6 +1145,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_relu(ctx0, cur); cb(cur, "ffn_moe_relu", il); } break; + case LLM_FFN_RELU_SQR: + if (gate_exps) { + // TODO: add support for gated squared relu + GGML_ABORT("fatal error: gated squared relu not implemented"); + } else { + cur = ggml_relu(ctx0, cur); + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_moe_relu_sqr", il); + } break; default: GGML_ABORT("fatal error"); } @@ -1203,7 +1268,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { } ggml_tensor * llm_graph_context::build_inp_attn_scale() const { - auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); + auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset); auto & cur = inp->attn_scale; @@ -1470,13 +1535,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1558,7 +1623,7 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1701,7 +1766,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; @@ -1767,7 +1832,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1781,7 +1846,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1841,6 +1906,9 @@ static std::unique_ptr build_rs_inp_impl( inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + inp->head = mctx_cur->get_head(); + inp->rs_z = mctx_cur->get_rs_z(); + return inp; } @@ -1909,10 +1977,10 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); - auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h index d0c3934f..81ac329c 100644 --- a/llama/llama.cpp/src/llama-graph.h +++ b/llama/llama.cpp/src/llama-graph.h @@ -132,8 +132,8 @@ public: // temperature tuning, used by llama4 class llm_graph_input_attn_temp : public llm_graph_input_i { public: - llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) - : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} + llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset) + : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {} virtual ~llm_graph_input_attn_temp() = default; void set_input(const llama_ubatch * ubatch) override; @@ -142,6 +142,7 @@ public: const uint32_t n_attn_temp_floor_scale; const float f_attn_temp_scale; + const float f_attn_temp_offset; }; class llm_graph_input_pos_bucket : public llm_graph_input_i { @@ -224,6 +225,8 @@ public: void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * s_copy; // I32 [n_rs] // views of s_copy, computed once per graph @@ -232,6 +235,10 @@ public: ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; + + // used in view offsets, need to match for valid graph reuse + uint32_t head; + int32_t rs_z; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -364,22 +371,28 @@ public: class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( + const llama_cparams & cparams, std::unique_ptr inp_attn, - std::unique_ptr inp_rs, - const llama_memory_hybrid_context * mctx) : + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), inp_rs(std::move(inp_rs)), + cparams(cparams), mctx(mctx) { } virtual ~llm_graph_input_mem_hybrid() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + std::unique_ptr inp_attn; std::unique_ptr inp_rs; llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + const llama_cparams cparams; + const llama_memory_hybrid_context * mctx; }; diff --git a/llama/llama.cpp/src/llama-hparams.cpp 
b/llama/llama.cpp/src/llama-hparams.cpp index 41127bf9..aabff2f0 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -1,6 +1,8 @@ #include "llama-hparams.h" #include "ggml.h" + +#include #include void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { @@ -237,3 +239,7 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama return false; } + +bool llama_hparams::use_mrope() const { + return rope_sections[0] > 0 && rope_sections[1] > 0; +} diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index a778fc3c..c6e67327 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -34,6 +34,7 @@ struct llama_hparams_convnext { struct llama_hparams { bool vocab_only; + bool no_alloc; bool rope_finetuned; bool use_par_res; bool swin_norm; @@ -109,6 +110,7 @@ struct llama_hparams { float rope_freq_base_train_swa; float rope_freq_scale_train; float rope_freq_scale_train_swa; + uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -166,6 +168,7 @@ struct llama_hparams { uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 0; float f_attn_temp_scale = 0.0f; + float f_attn_temp_offset = 0.0f; // offset position index // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs @@ -272,7 +275,8 @@ struct llama_hparams { // TODO: think of a better place for this function // TODO: pack the SWA params in a struct? static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + + bool use_mrope() const; }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); - diff --git a/llama/llama.cpp/src/llama-impl.cpp b/llama/llama.cpp/src/llama-impl.cpp index c7a1880a..8e3e7b22 100644 --- a/llama/llama.cpp/src/llama-impl.cpp +++ b/llama/llama.cpp/src/llama-impl.cpp @@ -25,6 +25,10 @@ time_meas::~time_meas() { } } +void llama_log_get(ggml_log_callback * log_callback, void ** user_data) { + ggml_log_get(log_callback, user_data); +} + void llama_log_set(ggml_log_callback log_callback, void * user_data) { ggml_log_set(log_callback, user_data); g_logger_state.log_callback = log_callback ? 
log_callback : llama_log_callback_default; diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index e26385a1..3186242d 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -175,7 +175,15 @@ llama_kv_cache::llama_kv_cache( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); + ggml_backend_buffer_t buf; + if (model.hparams.no_alloc) { + buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { + t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it + } + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer + } if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); } @@ -482,9 +490,18 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { std::map llama_kv_cache::memory_breakdown() const { std::map ret; - for (const auto & [_, buf] : ctxs_bufs) { - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + for (const auto & [ctx, buf] : ctxs_bufs) { + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf.get()); + + if (hparams.no_alloc) { + GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) == nullptr); + ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); + } else { + // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base + ret[buft] += ggml_backend_buffer_get_size(buf.get()); + } } + return ret; } @@ -1232,8 +1249,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u GGML_ASSERT(n_tokens%n_stream == 0); // n_tps == n_tokens_per_stream - const int64_t n_tps = n_tokens/n_stream; - const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); + const int64_t n_tps = n_tokens/n_stream; std::fill(data, data + ggml_nelements(dst), -INFINITY); @@ -1266,7 +1282,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; - const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii); + const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); for (uint32_t j = 0; j < n_kv; ++j) { if (cells.is_empty(j)) { @@ -1370,9 +1386,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift( float freq_scale) const { const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + const auto & yarn_attn_factor = cparams.yarn_attn_factor; const auto & n_rot = hparams.n_rot; const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE @@ -1383,12 +1400,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ? 
LLAMA_ROPE_TYPE_NEOX : hparams.rope_type; - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 - ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) - : cparams.yarn_attn_factor; - ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { @@ -1550,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id]; + slot_info sinfo; + bool res = true; - res = res && state_read_meta(io, strm, cell_count, seq_id); - res = res && state_read_data(io, strm, cell_count); + res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id); + res = res && state_read_data(io, strm, cell_count, sinfo); if (!res) { if (seq_id == -1) { @@ -1691,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t } } -bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -1728,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 ubatch.seq_id[i] = &dest_seq_id; } - const auto sinfo = find_slot(ubatch, true); + sinfo = find_slot(ubatch, false); if (sinfo.empty()) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; @@ -1738,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 apply_ubatch(sinfo, ubatch); - const auto head_cur = sinfo.head(); + LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id); - // keep the head at the old position because we will read the KV data into it in state_read_data() - head = head_cur; - - LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); - - // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head_cur + cell_count <= cells.size()); - GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]); - GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]); - GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); - GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values + GGML_ASSERT(sinfo.n_stream() == 1); + GGML_ASSERT(sinfo.idxs[0].size() == cell_count); + for (uint32_t i = 0; i < cell_count; ++i) { + const uint32_t idx = sinfo.idxs[0][i]; + GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]); + GGML_ASSERT(cells.seq_has(idx, dest_seq_id)); + } } else { // whole KV cache restore @@ -1784,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 } } + // Create contiguous slot_info for whole cache restore + sinfo.s0 = strm; + sinfo.s1 = strm; + sinfo.resize(1); + sinfo.strm[0] = strm; + sinfo.idxs[0].resize(cell_count); + for (uint32_t i = 0; i < cell_count; ++i) { + 
sinfo.idxs[0][i] = i; + } + head = 0; } return true; } -bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { +bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) { auto & cells = v_cells[strm]; - auto & head = v_heads[strm]; uint32_t v_trans; uint32_t n_layer; @@ -1842,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells, single memcpy + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + } else { + // Slow path: scatter to non-contiguous positions + const void * src = io.read(cell_count * k_size_row); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; + ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + } + } } } @@ -1874,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells, single memcpy + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + } else { + // Slow path: scatter to non-contiguous positions + const void * src = io.read(cell_count * v_size_row); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; + ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + } + } } } } else { @@ -1914,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 } if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + if (sinfo.is_contiguous()) { + // Fast path: contiguous cells + const uint32_t h = sinfo.head(); + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (h + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } else { + // Slow path: scatter to non-contiguous positions + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const void * src = io.read(cell_count * v_size_el); + for (uint32_t i = 0; i < cell_count; ++i) { + const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + } + } } } } diff --git a/llama/llama.cpp/src/llama-kv-cache.h b/llama/llama.cpp/src/llama-kv-cache.h index bf7821c0..1868f118 100644 --- a/llama/llama.cpp/src/llama-kv-cache.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -72,6 +72,23 @@ public: void clear() { idxs.clear(); } + + // check if indices are contiguous starting from head() + bool is_contiguous() const { + if (idxs.empty() || idxs[0].empty()) { + return true; + } + if (idxs.size() > 
1) { + return false; + } + const uint32_t h = idxs[0][0]; + for (size_t i = 0; i < idxs[0].size(); ++i) { + if (idxs[0][i] != h + i) { + return false; + } + } + return true; + } }; using slot_info_vec_t = std::vector<slot_info>; @@ -264,8 +281,8 @@ private: void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; - bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo); }; class llama_kv_cache_context : public llama_memory_context_i { diff --git a/llama/llama.cpp/src/llama-memory-hybrid.cpp b/llama/llama.cpp/src/llama-memory-hybrid.cpp index dfb8439e..a1b45e4a 100644 --- a/llama/llama.cpp/src/llama-memory-hybrid.cpp +++ b/llama/llama.cpp/src/llama-memory-hybrid.cpp @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), - ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index ee303bd5..8916a624 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -473,6 +473,7 @@ llama_model_loader::llama_model_loader( std::vector<std::string> & splits, bool use_mmap, bool check_tensors, + bool no_alloc, const llama_model_kv_override * param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; @@ -716,6 +717,7 @@ llama_model_loader::llama_model_loader( this->use_mmap = use_mmap; this->check_tensors = check_tensors; + this->no_alloc = no_alloc; } std::string llama_model_loader::get_arch_name() const { diff --git a/llama/llama.cpp/src/llama-model-loader.h b/llama/llama.cpp/src/llama-model-loader.h index c9189f6c..0380c92f 100644 --- a/llama/llama.cpp/src/llama-model-loader.h +++ b/llama/llama.cpp/src/llama-model-loader.h @@ -71,6 +71,7 @@ struct llama_model_loader { bool use_mmap = false; bool check_tensors; + bool no_alloc; llama_files files; llama_ftype ftype; @@ -97,6 +98,7 @@ struct llama_model_loader { std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, bool check_tensors, + bool no_alloc, const llama_model_kv_override * param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 3c503b42..00cd579e 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -120,6 +120,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_80B_A3B: return 
"80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; @@ -667,6 +669,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_swa = 8192; hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; + hparams.f_attn_temp_offset = 1.0f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full } @@ -1634,12 +1637,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { // that have no expert_gating_func model parameter set hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } // (optional) temperature tuning - used by mistral-large ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + hparams.f_attn_temp_offset = 0.0f; + switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; @@ -1679,7 +1689,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GLM4: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); switch (hparams.n_layer) { case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; @@ -1688,8 +1699,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GLM4_MOE: { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); // MoE parameters ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); @@ -1788,6 +1800,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); @@ -1803,7 +1816,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -2272,7 +2292,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - case 80: type = LLM_TYPE_80B_A3B; break; + case 48: type = LLM_TYPE_80B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -2281,9 +2301,11 @@ void 
llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f); + + hparams.f_attn_temp_offset = 0.0f; // TODO: maybe add n_attn_temp_floor_scale as a separate KV? if (hparams.f_attn_temp_scale != 0.0f) { @@ -2293,18 +2315,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } - // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f - // but may need further verification with other values - if (hparams.rope_yarn_log_mul != 0.0f) { - float factor = 1.0f / hparams.rope_freq_scale_train; - float mscale = 1.0f; - float mscale_all_dims = hparams.rope_yarn_log_mul; - static auto get_mscale = [](float scale, float mscale) { - return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); - }; - hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); - } - switch (hparams.n_layer) { case 26: type = LLM_TYPE_3B; break; case 34: type = LLM_TYPE_8B; break; @@ -3404,9 +3414,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -5175,6 +5185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { // mamba2 Mixer SSM params // NOTE: int64_t for tensor dimensions @@ -5185,6 +5196,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5234,12 +5248,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } } } } break; @@ -6650,9 +6678,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { std::vector bufs; if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { + GGML_ASSERT(!ml.no_alloc); for (uint32_t idx = 0; idx < ml.files.size(); idx++) { // only the mmap region containing the tensors in the model is mapped to the backend buffer - // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers + // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, + // then we could just use metal for all layers // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size void * addr = nullptr; size_t first, last; // NOLINT @@ -6668,9 +6698,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { bufs.emplace_back(buf); buf_map.emplace(idx, buf); } - } - else { - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + } else { + ggml_backend_buffer_t buf; + if (ml.no_alloc) { + buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = buf; // set dummy buffer for weights so that 
the backend scheduler won't try to allocate them + } + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + } if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } @@ -6725,6 +6762,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } + if (ml.no_alloc) { + return true; + } + // load tensor data for (auto & [ctx, buf_map] : ctx_buf_maps) { if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { @@ -6767,9 +6808,18 @@ size_t llama_model::n_devices() const { std::map llama_model::memory_breakdown() const { std::map ret; - for (const auto & [_, bufs] : pimpl->ctxs_bufs) { - for (const auto & buf : bufs) { - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { + if (hparams.no_alloc) { + GGML_ASSERT(bufs.size() == 1); + ggml_backend_buffer_t buf = bufs[0].get(); + GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); + ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); + } else { + for (const auto & buf : bufs) { + // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + } } } return ret; @@ -6814,6 +6864,7 @@ void llama_model::print_info() const { // hparams LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); if (!hparams.vocab_only) { LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); @@ -6848,6 +6899,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? 
"yes" : "unknown"); // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { @@ -6870,7 +6922,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_NEMOTRON_H) { + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -6911,7 +6964,6 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); } if (arch == LLM_ARCH_QWEN2MOE) { @@ -6926,7 +6978,8 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -7107,7 +7160,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (arch == LLM_ARCH_FALCON_H1) { filter_attn = [&](int32_t) { return true; }; filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H) { + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { filter_attn = [&](int32_t il) { return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; @@ -7478,6 +7531,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { llm = std::make_unique(*this, params); } break; @@ -7666,6 +7720,7 @@ llama_model_params llama_model_default_params() { /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, /*.no_host =*/ false, + /*.no_alloc =*/ false, }; return result; @@ -7765,6 +7820,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -7785,7 +7841,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: - case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: @@ -7848,7 +7903,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: - case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: case LLM_ARCH_GROVEMOE: case LLM_ARCH_APERTUS: @@ -7865,6 +7919,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3VLMOE: return LLAMA_ROPE_TYPE_IMROPE; + case LLM_ARCH_GLM4: + 
return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NORM; + case LLM_ARCH_GLM4_MOE: + return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index cbf4e1bf..b378b23e 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -114,6 +114,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index 351dcb7b..bc4b05c3 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -596,7 +596,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index f72f321b..d63ce9c8 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -1884,7 +1884,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { clean_spaces = false; } else if ( tokenizer_pre == "qwen2" || - tokenizer_pre == "deepseek-r1-qwen") { + tokenizer_pre == "deepseek-r1-qwen" || + tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( diff --git a/llama/llama.cpp/src/llama.cpp b/llama/llama.cpp/src/llama.cpp index 74c49e65..759152b7 100644 --- a/llama/llama.cpp/src/llama.cpp +++ b/llama/llama.cpp/src/llama.cpp @@ -1,6 +1,9 @@ +#include "llama.h" + #include "llama-impl.h" #include "llama-chat.h" +#include "llama-context.h" #include "llama-mmap.h" #include "llama-vocab.h" #include "llama-model-loader.h" @@ -11,11 +14,14 @@ #include "ggml-backend.h" #include +#include +#include #include #include #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -37,6 +43,646 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } +struct llama_device_memory_data { + int64_t total; + int64_t free; + llama_memory_breakdown_data mb; +}; + +static std::vector llama_get_device_memory_data( + const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, + std::vector & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert, + const ggml_log_level log_level) { + struct user_data_t { + struct { + ggml_log_callback callback; + void * user_data; + } original_logger; + ggml_log_level min_level; // prints below this log level go to debug log + }; + user_data_t ud; + llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); + ud.min_level = log_level; + + llama_log_set([](ggml_log_level level, const char * text, void * user_data) { + const user_data_t * ud = (const user_data_t *) user_data; + const ggml_log_level level_eff = level >= ud->min_level ? 
level : GGML_LOG_LEVEL_DEBUG; + ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); + }, &ud); + + llama_model_params mparams_copy = *mparams; + mparams_copy.no_alloc = true; + mparams_copy.use_mmap = false; + + llama_model * model = llama_model_load_from_file(path_model, mparams_copy); + if (model == nullptr) { + llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); + throw std::runtime_error("failed to load model"); + } + + llama_context * ctx = llama_init_from_model(model, *cparams); + if (ctx == nullptr) { + llama_model_free(model); + llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); + throw std::runtime_error("failed to create llama_context from model"); + } + + std::vector ret(model->devices.size()); + + std::map memory_breakdown = ctx->memory_breakdown(); + + for (const auto & [buft, mb] : memory_breakdown) { + if (ggml_backend_buft_is_host(buft)) { + continue; + } + + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + continue; + } + for (size_t i = 0; i < ret.size(); i++) { + if (model->devices[i] == dev) { + ret[i].mb.model += mb.model; + ret[i].mb.context += mb.context; + ret[i].mb.compute += mb.compute; + break; + } + } + } + for (size_t i = 0; i < ret.size(); i++) { + size_t free, total; + ggml_backend_dev_memory(model->devices[i], &free, &total); + ret[i].free = free; + ret[i].total = total; + } + + devs = model->devices; + hp_ngl = model->hparams.n_layer; + hp_n_ctx_train = model->hparams.n_ctx_train; + hp_n_expert = model->hparams.n_expert; + + llama_memory_breakdown_print(ctx); // goes to debug log + + llama_free(ctx); + llama_model_free(model); + llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); + return ret; +} + +// enum to identify part of a layer for distributing its tensors: +enum layer_fraction_t { + LAYER_FRACTION_NONE = 0, // nothing + LAYER_FRACTION_ATTN = 1, // attention + LAYER_FRACTION_UP = 2, // attention + up + LAYER_FRACTION_GATE = 3, // attention + up + gate + LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights +}; +// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue + +static void llama_params_fit_impl( + const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, + float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, + size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + constexpr int64_t MiB = 1024*1024; + const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits + typedef std::vector dmds_t; + const llama_model_params default_mparams = llama_model_default_params(); + + std::vector devs; + uint32_t hp_ngl = 0; // hparams.n_gpu_layers + uint32_t hp_nct = 0; // hparams.n_ctx_train + uint32_t hp_nex = 0; // hparams.n_expert + + // step 1: get data for default parameters and check whether any changes are necessary in the first place + + LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); + const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + const size_t nd = devs.size(); // number of devices + if (nd == 0) { + LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); + return; + } + + std::vector dev_names; + { + dev_names.reserve(nd); + size_t max_length = 0; + for 
(ggml_backend_dev_t dev : devs) { + std::string name = ggml_backend_dev_name(dev); + name += " ("; + name += ggml_backend_dev_description(dev); + name += ")"; + dev_names.push_back(name); + max_length = std::max(max_length, name.length()); + } + for (std::string & dn : dev_names) { + dn.insert(dn.end(), max_length - dn.length(), ' '); + } + } + + int64_t sum_total = 0; + int64_t sum_projected_free = 0; + int64_t min_projected_free = INT64_MAX; + int64_t sum_projected_used = 0; + int64_t sum_projected_ctx = 0; + + if (nd > 1) { + LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); + } + for (size_t id = 0; id < nd; id++) { + const llama_device_memory_data & dmd = dmds_full[id]; + + const int64_t projected_used = dmd.mb.total(); + const int64_t projected_free = dmd.free - projected_used; + + sum_total += dmd.total; + sum_projected_used += projected_used; + sum_projected_free += projected_free; + min_projected_free = std::min(min_projected_free, projected_free); + sum_projected_ctx += dmd.mb.context; + + if (nd > 1) { + LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n", + __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB, + projected_free >= 0 ? "surplus" : "deficit"); + } + } + assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0); + assert(sum_projected_used >= sum_projected_ctx); + LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", + __func__, sum_projected_used/MiB, sum_total/MiB); + if (min_projected_free >= margin) { + if (nd == 1) { + LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", + __func__, min_projected_free/MiB, margin/MiB); + return; + } + LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n", + __func__, min_projected_free/MiB, margin/MiB); + return; + } + + // step 2: try reducing memory use by reducing the context size + + { + int64_t global_surplus = sum_projected_free - int64_t(nd)*margin; + if (global_surplus < 0) { + LLAMA_LOG_INFO(nd == 1 ? + "%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" : + "%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n", + __func__, margin/MiB, -global_surplus/MiB); + if (cparams->n_ctx == 0) { + if (hp_nct > n_ctx_min) { + const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct; + const uint32_t ctx_reduction = std::min( + uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min); + cparams->n_ctx = hp_nct - ctx_reduction; + const int64_t memory_reduction = ctx_reduction * bytes_per_ctx; + global_surplus += memory_reduction; + LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", + __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); + if (global_surplus >= 0) { + if (nd == 1) { + LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); + return; + } + LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); + } + } else { + LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. 
context size of %" PRIu32 " -> no change\n", + __func__, hp_nct, n_ctx_min); + } + } else { + LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); + } + } + } + + if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { + throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); + } + if (nd > 1) { + if (!tensor_split) { + throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort"); + } + if (mparams->tensor_split) { + for (size_t id = 0; id < nd; id++) { + if (mparams->tensor_split[id] != 0.0f) { + throw std::runtime_error("model_params::tensor_split already set by user, abort"); + } + } + } + if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { + throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); + } + if (hp_ngl < 2*nd) { + throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least " + + std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort"); + } + } + if (!tensor_buft_overrides) { + throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort"); + } + if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { + throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort"); + } + + // step 3: iteratively fill the back to front with "dense" layers + // - for a dense model simply fill full layers, giving each device a contiguous slice of the model + // - for a MoE model, same as dense model but with all MoE tensors in system memory + + // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: + auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * { + constexpr size_t n_strings = 1000; + if (il >= n_strings) { + throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); + } + switch (lf) { + case LAYER_FRACTION_ATTN: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*"; + } + return patterns[il].c_str(); + } + case LAYER_FRACTION_UP: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*"; + } + return patterns[il].c_str(); + } + case LAYER_FRACTION_GATE: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; + } + return patterns[il].c_str(); + } + case LAYER_FRACTION_MOE: { + static std::array patterns; + if (patterns[il].empty()) { + patterns[il] = "blk\\." 
+ std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps"; + } + return patterns[il].c_str(); + } + default: + GGML_ABORT("fatal error"); + } + }; + + struct ngl_t { + uint32_t n_layer = 0; // number of total layers + uint32_t n_part = 0; // number of partial layers, <= n_layer + + // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: + layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + }; + + const size_t ntbo = llama_max_tensor_buft_overrides(); + + // utility function to set n_gpu_layers and tensor_split + auto set_ngl_tensor_split_tbo = [&]( + const std::vector & ngl_per_device, + const std::vector & overflow_bufts, + llama_model_params & mparams, + const bool add_nonrepeating) { + mparams.n_gpu_layers = 0; + for (size_t id = 0; id < nd; id++) { + mparams.n_gpu_layers += ngl_per_device[id].n_layer; + if (nd > 1) { + tensor_split[id] = ngl_per_device[id].n_layer; + } + } + assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl); + uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides + + if (add_nonrepeating) { + mparams.n_gpu_layers += 1; + tensor_split[nd - 1] += 1; + } + mparams.tensor_split = tensor_split; + + size_t itbo = 0; + for (size_t id = 0; id < nd; id++) { + il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part; + for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { + if (itbo + 1 >= ntbo) { + tensor_buft_overrides[itbo].pattern = nullptr; + tensor_buft_overrides[itbo].buft = nullptr; + itbo++; + mparams.tensor_buft_overrides = tensor_buft_overrides; + throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == " + + std::to_string(ntbo) + " is insufficient for model\n"); + } + tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? 
ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); + tensor_buft_overrides[itbo].buft = overflow_bufts[id]; + itbo++; + } + il0 += ngl_per_device[id].n_part; + } + tensor_buft_overrides[itbo].pattern = nullptr; + tensor_buft_overrides[itbo].buft = nullptr; + itbo++; + mparams.tensor_buft_overrides = tensor_buft_overrides; + }; + + // utility function that returns the memory use per device for given numbers of layers per device + auto get_memory_for_layers = [&]( + const char * func_name, + const std::vector & ngl_per_device, + const std::vector & overflow_bufts, + const bool add_nonrepeating) -> std::vector { + llama_model_params mparams_copy = *mparams; + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating); + + const dmds_t dmd_nl = llama_get_device_memory_data( + path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + + LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name); + for (size_t id = 0; id < nd; id++) { + const ngl_t & n = ngl_per_device[id]; + LLAMA_LOG_DEBUG( + "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", + func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); + } + + std::vector ret; + ret.reserve(nd); + for (const llama_device_memory_data & dmd : dmd_nl) { + ret.push_back(dmd.mb.total()); + } + return ret; + }; + + int64_t global_surplus_cpu_moe = 0; + if (hp_nex > 0) { + const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors + ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); + tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; + tensor_buft_overrides[1] = {nullptr, nullptr}; + mparams->tensor_buft_overrides = tensor_buft_overrides; + + LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); + const dmds_t dmds_cpu_moe = llama_get_device_memory_data( + path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + + for (const llama_device_memory_data & dmd : dmds_cpu_moe) { + global_surplus_cpu_moe += dmd.free; + global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin; + } + + if (global_surplus_cpu_moe > 0) { + LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", + __func__, global_surplus_cpu_moe/MiB); + } else { + LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", + __func__, -global_surplus_cpu_moe/MiB); + } + + // reset + tensor_buft_overrides[0] = {nullptr, nullptr}; + mparams->tensor_buft_overrides = tensor_buft_overrides; + } + + std::vector targets; // maximum acceptable memory use per device + targets.reserve(nd); + for (size_t id = 0; id < nd; id++) { + targets.push_back(dmds_full[id].free - margin); + LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); + } + + // whether for the optimal memory use we expect to load at least some MoE tensors: + const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0; + + std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: + overflow_bufts.reserve(nd); + for (size_t id = 0; id < nd - 1; ++id) { + overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1])); + } + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); + + std::vector ngl_per_device(nd); + std::vector mem = 
get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe); + if (hp_nex > 0) { + for (size_t id = 0; id < nd; id++) { + ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; + } + } + + // optimize the number of layers per device using the method of false position: + // - ngl_per_device has 0 layers for each device, lower bound + // - try a "high" configuration where a device is given all unassigned layers + // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target + // - check memory use of our guess, replace either the low or high bound + // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits + if (hp_nex == 0) { + LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); + } else { + LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); + } + uint32_t n_unassigned = hp_ngl; + for (int id = nd - 1; id >= 0; id--) { + std::vector ngl_per_device_high = ngl_per_device; + ngl_per_device_high[id].n_layer = n_unassigned; + if (hp_nex > 0) { + ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer; + } + if (ngl_per_device_high[id].n_layer > 0) { + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe); + if (mem_high[id] > targets[id]) { + uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; + while (delta > 1) { + uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); + step_size = std::max(step_size, uint32_t(1)); + step_size = std::min(step_size, delta - 1); + + std::vector ngl_per_device_test = ngl_per_device; + ngl_per_device_test[id].n_layer += step_size; + if (hp_nex) { + ngl_per_device_test[id].n_part += step_size; + } + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + + if (mem_test[id] <= targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + n_unassigned -= ngl_per_device[id].n_layer; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); + } else { + ngl_per_device_high = ngl_per_device_test; + mem_high = mem_test; + LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); + } + delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; + } + } else { + ngl_per_device = ngl_per_device_high; + n_unassigned -= ngl_per_device[id].n_layer; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); + } + } + + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); + } + if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe); + return; + } + + // step 4: for a MoE model where all dense tensors fit, + // convert the dense-only layers in the back to full layers in the front until all devices are full + // essentially the same procedure as for the dense-only layers except front-to-back + // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 
24 GiB VRAM + + size_t id_dense_start = nd; + for (int id = nd - 1; id >= 0; id--) { + if (ngl_per_device[id].n_layer > 0) { + id_dense_start = id; + continue; + } + break; + } + assert(id_dense_start < nd); + + LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); + for (size_t id = 0; id <= id_dense_start; id++) { + std::vector ngl_per_device_high = ngl_per_device; + for (size_t jd = id_dense_start; jd < nd; jd++) { + const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer; + ngl_per_device_high[id].n_layer += n_layer_move; + ngl_per_device_high[jd].n_layer -= n_layer_move; + ngl_per_device_high[jd].n_part = 0; + } + size_t id_dense_start_high = nd - 1; + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe); + + if (mem_high[id] > targets[id]) { + assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); + assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part); + assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) + >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) + - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + while (delta > 1) { + uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); + step_size = std::max(step_size, uint32_t(1)); + step_size = std::min(step_size, delta - 1); + + std::vector ngl_per_device_test = ngl_per_device; + size_t id_dense_start_test = id_dense_start; + uint32_t n_converted_test = 0; + for (;id_dense_start_test < nd; id_dense_start_test++) { + const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); + ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; + ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; + ngl_per_device_test[id].n_layer += n_convert_jd; + n_converted_test += n_convert_jd; + + if (ngl_per_device_test[id_dense_start_test].n_layer > 0) { + break; + } + } + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + + if (mem_test[id] <= targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } else { + ngl_per_device_high = ngl_per_device_test; + mem_high = mem_test; + id_dense_start_high = id_dense_start_test; + LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", + __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); + } + delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) + - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + } + } else { + ngl_per_device = ngl_per_device_high; + id_dense_start = id_dense_start_high; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } + + // try to fit at least part of one more layer + if (ngl_per_device[id_dense_start].n_layer > 0) { + std::vector 
ngl_per_device_test = ngl_per_device; + size_t id_dense_start_test = id_dense_start; + ngl_per_device_test[id_dense_start_test].n_layer--; + ngl_per_device_test[id_dense_start_test].n_part--; + ngl_per_device_test[id].n_layer++; + ngl_per_device_test[id].n_part++; + if (ngl_per_device_test[id_dense_start_test].n_layer == 0) { + id_dense_start_test++; + } + ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + if (mem_test[id] < targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + + ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; + LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + if (mem_test[id] < targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } + } else { + ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; + LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + if (mem_test[id] < targets[id]) { + ngl_per_device = ngl_per_device_test; + mem = mem_test; + id_dense_start = id_dense_start_test; + LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", + __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); + } + } + } + + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe); +} + +bool llama_params_fit( + const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, + float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, + size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + const int64_t t0_us = llama_time_us(); + bool ok = true; + try { + llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level); + LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); + } catch (const std::runtime_error & e) { + LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); + ok = false; + } + const int64_t t1_us = llama_time_us(); + LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); + 
return ok; +} + struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -49,6 +695,10 @@ size_t llama_max_devices(void) { return 16; } +size_t llama_max_tensor_buft_overrides() { + return 4096; +} + bool llama_supports_mmap(void) { return llama_mmap::SUPPORTED; } @@ -108,11 +758,12 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides); + llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); model.hparams.vocab_only = params.vocab_only; + model.hparams.no_alloc = params.no_alloc; try { model.load_arch(ml); diff --git a/llama/llama.cpp/src/models/deepseek2.cpp b/llama/llama.cpp/src/models/deepseek2.cpp index dbaa8297..49382874 100644 --- a/llama/llama.cpp/src/models/deepseek2.cpp +++ b/llama/llama.cpp/src/models/deepseek2.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B @@ -20,9 +18,15 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); - const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); - const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + + // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor + GGML_ASSERT(ext_factor >= 0.0f); + const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)); + + // use the original attn_factor to pre-scale the kq_scale + const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/llama/llama.cpp/src/models/glm4-moe.cpp b/llama/llama.cpp/src/models/glm4-moe.cpp index 33ee7070..003f70f7 100644 --- a/llama/llama.cpp/src/models/glm4-moe.cpp +++ b/llama/llama.cpp/src/models/glm4-moe.cpp @@ -5,11 +5,20 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + bool use_mrope = hparams.use_mrope(); + if (ubatch.embd && !use_mrope) { + // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results + GGML_ABORT("This GGUF does not support multimodal. 
Please reconvert it."); + } + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -60,17 +69,25 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); } - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + if (use_mrope) { + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else { + // Normal RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); diff --git a/llama/llama.cpp/src/models/glm4.cpp b/llama/llama.cpp/src/models/glm4.cpp index f789b282..204aa393 100644 --- a/llama/llama.cpp/src/models/glm4.cpp +++ b/llama/llama.cpp/src/models/glm4.cpp @@ -8,11 +8,20 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + bool use_mrope = hparams.use_mrope(); + if (ubatch.embd && !use_mrope) { + // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results + GGML_ABORT("This GGUF does not support multimodal. 
Please reconvert it."); + } + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -63,11 +72,25 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); + if (use_mrope) { + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else { + // Normal RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); diff --git a/llama/llama.cpp/src/models/models.h b/llama/llama.cpp/src/models/models.h index e0aec822..6d84a185 100644 --- a/llama/llama.cpp/src/models/models.h +++ b/llama/llama.cpp/src/models/models.h @@ -441,23 +441,13 @@ private: ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - ggml_tensor * build_delta_net_recurrent( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - int il); - ggml_tensor * build_delta_net_chunking( ggml_tensor * q, ggml_tensor * k, @@ -467,8 +457,18 @@ private: ggml_tensor * state, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il); + ggml_tensor * build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/llama/llama.cpp/src/models/nemotron-h.cpp b/llama/llama.cpp/src/models/nemotron-h.cpp index 54143488..eb135e63 100644 --- a/llama/llama.cpp/src/models/nemotron-h.cpp +++ b/llama/llama.cpp/src/models/nemotron-h.cpp @@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * } ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * ffn_inp = cur; + ggml_tensor * 
moe_out = + build_moe_ffn(ffn_inp, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, // no gate + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_RELU_SQR, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * ffn_shexp = build_ffn(ffn_inp, + model.layers[il].ffn_up_shexp, NULL, NULL, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } cur = build_cvec(cur, il); cb(cur, "l_out", il); diff --git a/llama/llama.cpp/src/models/qwen2.cpp b/llama/llama.cpp/src/models/qwen2.cpp index 587a9324..3da4dea3 100644 --- a/llama/llama.cpp/src/models/qwen2.cpp +++ b/llama/llama.cpp/src/models/qwen2.cpp @@ -31,16 +31,25 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); diff --git a/llama/llama.cpp/src/models/qwen3next.cpp b/llama/llama.cpp/src/models/qwen3next.cpp index c8f1b5ec..775b3135 100644 --- a/llama/llama.cpp/src/models/qwen3next.cpp +++ b/llama/llama.cpp/src/models/qwen3next.cpp @@ -17,13 +17,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_tensor * inp_out_ids = build_inp_out_ids(); ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f), + ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), GGML_TRI_TYPE_LOWER); - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f)); + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); + ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); ggml_build_forward_expand(gf, causal_mask); ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -34,7 +36,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il); + cur = build_layer_attn_linear(inp->get_recr(), 
cur, causal_mask, identity, diag_mask, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); @@ -93,14 +95,8 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * state, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il) { - GGML_ASSERT(ggml_is_contiguous(q)); - GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(g)); - GGML_ASSERT(ggml_is_contiguous(beta)); - GGML_ASSERT(ggml_is_contiguous(state)); - const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; const int64_t n_tokens = q->ne[2]; @@ -120,15 +116,10 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - // TODO: can this ever be false? - const bool use_qk_l2norm = true; + const float eps_norm = hparams.f_norm_rms_eps; - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); const float scale = 1.0f / sqrtf(S_v); @@ -136,8 +127,6 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); - cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); @@ -188,36 +177,21 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( cb(v_beta, "v_beta", il); cb(k_beta, "k_beta", il); - ggml_tensor * chunked_mask = - ggml_view_4d(ctx0, causal_mask, chunk_size, - chunk_size, causal_mask->ne[2], causal_mask->ne[3], - causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0); + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - ggml_tensor * chunked_diag_mask = - ggml_view_4d(ctx0, causal_diag_mask, chunk_size, - chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3], - causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0); - - ggml_tensor * chunked_identity = - ggml_view_4d(ctx0, identity, chunk_size, - chunk_size, identity->ne[2], identity->ne[3], - identity->nb[1], identity->nb[2], identity->nb[3], 0); - - q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); cb(g_cumsum, "g_cumsum", il); - ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_i = 
ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); @@ -226,23 +200,23 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( cb(decay_mask, "decay_mask", il); - decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask)); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); cb(attn, "attn_pre_solve", il); - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower); + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, chunked_mask); - attn = ggml_add(ctx0, attn, chunked_identity); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); cb(attn, "attn_solved", il); @@ -291,7 +265,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); attn = ggml_mul(ctx0, attn, decay_mask_chunk); - attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask)); + attn = ggml_mul(ctx0, attn, diag_mask); ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); @@ -361,23 +335,14 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( return ggml_concat(ctx0, flat_output, flat_state, 0); } -ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( +ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, int il) { - GGML_ASSERT(ggml_is_contiguous(q)); - GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(g)); - GGML_ASSERT(ggml_is_contiguous(beta)); - GGML_ASSERT(ggml_is_contiguous(state)); - const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; const int64_t n_tokens = q->ne[2]; @@ -386,6 +351,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( const int64_t S_v = v->ne[0]; const int64_t H_v = v->ne[1]; + GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing GGML_ASSERT(v->ne[2] == n_tokens); GGML_ASSERT(k->ne[2] == n_tokens); GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); @@ -397,215 +363,65 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - // TODO: can this ever be false? 
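The decay mask assembled above encodes, for one chunk, how much the contribution of an earlier position has decayed by the time a later position reads it: with b = cumsum(g) inside the chunk, position i attends to position j <= i with weight exp(b_i - b_j), and everything above the diagonal is zeroed. A reference sketch of that semantics in plain loops; the ggml graph realizes it with broadcasting and different index conventions, so this is illustrative only.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Build the per-chunk decay mask from the per-position gates g (log-space, g <= 0).
    std::vector<std::vector<float>> build_decay_mask(const std::vector<float> & g) {
        const std::size_t n = g.size();              // chunk size
        std::vector<float> b(n);                     // running cumulative sum of the gates
        float acc = 0.0f;
        for (std::size_t i = 0; i < n; ++i) { acc += g[i]; b[i] = acc; }

        std::vector<std::vector<float>> D(n, std::vector<float>(n, 0.0f));
        for (std::size_t i = 0; i < n; ++i)
            for (std::size_t j = 0; j <= i; ++j)
                D[i][j] = std::exp(b[i] - b[j]);     // decay accumulated between j and i
        return D;                                    // strictly upper part stays zero (causal)
    }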
- const bool use_qk_l2norm = true; + const float eps_norm = hparams.f_norm_rms_eps; - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); const float scale = 1.0f / sqrtf(S_v); - q = ggml_scale(ctx0, q, scale); - + q = ggml_scale(ctx0, q, scale); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); - cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); cb(beta, "beta_in", il); cb(g, "g_in", il); - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + // Apply the gated delta rule for the single timestep + // last_recurrent_state = last_recurrent_state * g_t + state = ggml_mul(ctx0, state, g_t); - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // we need to sum over dim=-2, so we transpose, sum, then transpose again + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - cb(k_beta, "k_beta", il); - cb(v_beta, "v_beta", il); - cb(g_cumsum, "g_cumsum", il); + // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + // delta = (v_t - kv_mem) * beta_t + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs] - ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs] + // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); - // Broadcast both tensors to [chunk_size, chunk_size, H_v, 
n_seqs] - // ggml_tensor * gcs_i_broadcast = - // ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, - // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] - // Don't need this, this one will get auto-broadcast - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - - // Apply lower triangular mask to ensure attention is causal (only past tokens influence current) - decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); - // Apply exponential to get the decay mask values - decay_mask = ggml_exp(ctx0, decay_mask); - // Apply lower triangular mask again to ensure only lower triangular values remain - decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); - - cb(decay_mask, "decay_mask", il); - - // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - cb(kmulkbeta, "kmulkbeta", il); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - - cb(attn, "attn_pre_rec", il); - - // for i in range(1, chunk_size): - // row = attn[..., i, :i].clone() - // sub = attn[..., :i, :i].clone() - // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - // - // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - - // value = attn @ v_beta - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - cb(v, "value_beta", il); - - // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - cb(gexp, "g_cum_exp", il); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - - cb(kbeta_gexp, "kbeta_gexp", il); - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - - cb(k_cumdecay, "k_cumdecay", il); - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - attn = ggml_mul_mat(ctx0, k, q); - attn = ggml_mul(ctx0, attn, decay_mask); - attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask)); - - cb(attn, "attn_decay_key", il); - - ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay); - - cb(v_prime, "v_prime", il); - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime); - - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - - cb(v_new, "v_new", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); 
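The recurrent path being removed here is superseded by build_delta_net_autoregressive, which applies the gated delta rule for a single token. A reference sketch of that one-step update for a single head, assuming the state is indexed [key dim][value dim]; the L2 normalization of q/k and the 1/sqrt(d_v) scaling of q, done earlier in the function, are taken as already applied.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One gated-delta-rule step per head, mirroring the ggml ops above in plain loops.
    void delta_rule_step(std::vector<std::vector<float>> & S,  // [d_k][d_v] recurrent state
                         const std::vector<float> & q,         // [d_k], normalized and scaled
                         const std::vector<float> & k,         // [d_k], normalized
                         const std::vector<float> & v,         // [d_v]
                         float g, float beta_raw,              // per-head gate (log-space) and mix logit
                         std::vector<float> & out) {           // [d_v] attention output
        const std::size_t d_k = S.size(), d_v = S[0].size();
        const float g_exp = std::exp(g);
        const float b     = 1.0f / (1.0f + std::exp(-beta_raw)); // beta goes through a sigmoid

        // decay the state, then read the memory at k: kv_mem = S^T k
        std::vector<float> kv_mem(d_v, 0.0f);
        for (std::size_t i = 0; i < d_k; ++i)
            for (std::size_t j = 0; j < d_v; ++j) {
                S[i][j] *= g_exp;
                kv_mem[j] += S[i][j] * k[i];
            }

        // rank-1 correction: S += k (x) ((v - kv_mem) * b)
        for (std::size_t i = 0; i < d_k; ++i)
            for (std::size_t j = 0; j < d_v; ++j)
                S[i][j] += k[i] * (v[j] - kv_mem[j]) * b;

        // output = S^T q, read from the updated state
        out.assign(d_v, 0.0f);
        for (std::size_t i = 0; i < d_k; ++i)
            for (std::size_t j = 0; j < d_v; ++j)
                out[j] += S[i][j] * q[i];
    }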
- - cb(attn_inter, "attn_inter", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); - - cb(v_attn, "v_attn", il); - - ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn); - - cb(core_attn_out, "core_attn_out", il); - - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - ggml_tensor * g_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], - g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3], - g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1))); - - cb(g_cum_last, "g_cum_last", il); - - ggml_tensor * gexp_last = - ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); - - cb(gexp_last, "gexp_last", il); - - ggml_tensor * g_cum_last_3d = - ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); - - cb(g_cum_last_3d, "g_cum_last_3d", il); - - ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]); - - cb(g_cumsum_3d, "g_cumsum_3d", il); - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); - - cb(g_diff, "g_diff", il); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - - cb(g_diff_exp, "g_diff_exp", il); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, - ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], - g_diff_exp->ne[2] * g_diff_exp->ne[3])); - - cb(key_gdiff, "key_gdiff", il); - - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); - - cb(kgdmulvnew, "kgdmulvnew", il); - - state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew); + // Compute the attention output + // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // again, since it's over dim = -2, transpose, sum, transpose back + ggml_tensor * core_attn_out = + ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this + cb(core_attn_out, "output_tokens", il); cb(state, "new_state", il); - // flatten output - ggml_tensor * flat_output = - ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); - - ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs); + // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise + ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs); + ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs); return ggml_concat(ctx0, flat_output, flat_state, 0); } @@ -712,6 +528,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * cur, ggml_tensor * causal_mask, ggml_tensor * identity, + ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -737,11 +554,11 @@ ggml_tensor * 
llm_build_qwen3next::build_layer_attn_linear( cb(mixed_ba, "linear_attn_mixed_ba", il); int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); - ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; - ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); + ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); // Split mixed_ba into b and a (beta and alpha parameters) int64_t split_sizes_ba[2] = { @@ -762,8 +579,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); - GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba)); - ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); @@ -799,9 +614,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); cb(z, "z", il); - GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) == - ggml_nelements(mixed_qkvz)); - // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); @@ -925,10 +737,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens - ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ? 
- build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) : - build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il); + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + ggml_tensor * attn_out; + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + } cb(attn_out, "attn_out", il); // The tensors were concatenated 1d, so we need to extract them 1d as well diff --git a/llama/llama.cpp/tools/mtmd/clip-graph.h b/llama/llama.cpp/tools/mtmd/clip-graph.h new file mode 100644 index 00000000..2b191577 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/clip-graph.h @@ -0,0 +1,121 @@ +#pragma once + +#include "ggml.h" +#include "ggml-cpp.h" +#include "clip.h" +#include "clip-impl.h" +#include "clip-model.h" + +#include +#include + +#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS) + +struct clip_graph { + const clip_model & model; + const clip_hparams & hparams; + projector_type proj_type; + + // we only support single image per batch + const clip_image_f32 & img; + + const int patch_size; + const int n_patches_x; + const int n_patches_y; + const int n_patches; + const int n_embd; + const int n_head; + const int d_head; + const int n_layer; + const int n_mmproj_embd; + const float eps; + const float kq_scale; + const clip_flash_attn_type flash_attn_type; + + // for debugging + const bool debug_graph; + std::vector & debug_print_tensors; + + ggml_context_ptr ctx0_ptr; + ggml_context * ctx0; + ggml_cgraph * gf; + + clip_graph(clip_ctx * ctx, const clip_image_f32 & img); + + virtual ~clip_graph() = default; + virtual ggml_cgraph * build() = 0; + + // + // utility functions + // + void cb(ggml_tensor * cur0, const char * name, int il) const; + + // siglip2 naflex + ggml_tensor * resize_position_embeddings(uint32_t interpolation_mode = DEFAULT_INTERPOLATION_MODE); + + // build vision transformer (ViT) cgraph + // this function should cover most of the models + // if your model has specific features, you should probably duplicate this function + ggml_tensor * build_vit( + ggml_tensor * inp, + int64_t n_pos, + norm_type norm_t, + ffn_op_type ffn_t, + ggml_tensor * learned_pos_embd, + std::function add_pos); + + // build the input after conv2d (inp_raw --> patches) + // returns tensor with shape [n_embd, n_patches] + ggml_tensor * build_inp(); + + ggml_tensor * build_inp_raw(int channels = 3); + + ggml_tensor * build_norm( + ggml_tensor * cur, + ggml_tensor * mw, + ggml_tensor * mb, + norm_type type, + float norm_eps, + int il) const; + + ggml_tensor * build_ffn( + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * gate, + ggml_tensor * gate_b, + ggml_tensor * down, + ggml_tensor * down_b, + ffn_op_type type_op, + int il) const; + + ggml_tensor * build_attn( + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_mask, + float kq_scale, + int il) const; + + // implementation of the 2D RoPE without adding a new op in ggml + // this is not efficient (use double the memory), but works on all backends + // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well 
with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065 + ggml_tensor * build_rope_2d( + ggml_context * ctx0, + ggml_tensor * cur, + ggml_tensor * pos_a, // first half + ggml_tensor * pos_b, // second half + const float freq_base, + const bool interleave_freq + ); + + // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) + // support dynamic resolution + ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor); + + // Generic function to stack frames for audio processing + // Abstracts out the StackAudioFrames logic used by ultravox + ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed); +}; diff --git a/llama/llama.cpp/tools/mtmd/clip-impl.h b/llama/llama.cpp/tools/mtmd/clip-impl.h index cd47865b..d75233cc 100644 --- a/llama/llama.cpp/tools/mtmd/clip-impl.h +++ b/llama/llama.cpp/tools/mtmd/clip-impl.h @@ -1,3 +1,5 @@ +#pragma once + #include "ggml.h" #include "gguf.h" #include "clip.h" @@ -13,6 +15,8 @@ // Internal header for clip.cpp +#define MTMD_INTERNAL_HEADER + #define KEY_FTYPE "general.file_type" #define KEY_NAME "general.name" #define KEY_DESCRIPTION "general.description" @@ -64,6 +68,7 @@ #define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" #define TN_PATCH_BIAS "v.patch_embd.bias" +#define TN_NORM_EMBD "v.norm_embd.%s" #define TN_ATTN_QKV "%s.blk.%d.attn_qkv.%s" #define TN_ATTN_K "%s.blk.%d.attn_k.%s" #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" @@ -82,6 +87,10 @@ #define TN_LN_PRE "%s.pre_ln.%s" #define TN_LN_POST "%s.post_ln.%s" #define TN_LLAVA_PROJ "mm.%d.%s" +#define TN_MM_UP "mm.up.%s" +#define TN_MM_GATE "mm.gate.%s" +#define TN_MM_DOWN "mm.down.%s" +#define TN_MM_POST_NORM "mm.post_norm.%s" #define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s" #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" @@ -91,7 +100,7 @@ #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 #define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 -#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1 +#define TN_MM_PATCH_MERGER "mm.patch_merger.%s" // mistral small 3.1, glm4v #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model) #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model) @@ -132,6 +141,10 @@ // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) +// forward declaration +// TODO: improve this later +struct clip_ctx; + enum projector_type { PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP_NORM, @@ -149,6 +162,7 @@ enum projector_type { PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_QWEN2A, + PROJECTOR_TYPE_GLMA, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, @@ -156,6 +170,7 @@ enum projector_type { PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, + PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_UNKNOWN, }; @@ -175,6 +190,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_INTERNVL, "internvl"}, { PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, + { PROJECTOR_TYPE_GLMA, "glma"}, { PROJECTOR_TYPE_QWEN25O, 
"qwen2.5o"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, @@ -182,6 +198,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, + { PROJECTOR_TYPE_GLM4V, "glm4v"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { @@ -485,6 +502,8 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { } } +void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value); + // // API used internally with mtmd // diff --git a/llama/llama.cpp/tools/mtmd/clip-model.h b/llama/llama.cpp/tools/mtmd/clip-model.h new file mode 100644 index 00000000..f5c41ff1 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/clip-model.h @@ -0,0 +1,300 @@ +#pragma once + +#include "ggml.h" +#include "clip.h" +#include "clip-impl.h" + +#include +#include +#include +#include + +enum ffn_op_type { + FFN_GELU, + FFN_GELU_ERF, + FFN_SILU, + FFN_GELU_QUICK, +}; + +enum norm_type { + NORM_TYPE_NORMAL, + NORM_TYPE_RMS, +}; + +enum patch_merge_type { + PATCH_MERGE_FLAT, + PATCH_MERGE_SPATIAL_UNPAD, +}; + +struct clip_hparams { + int32_t image_size = 0; + int32_t patch_size = 0; + int32_t n_embd = 0; + int32_t n_ff = 0; + int32_t projection_dim = 0; + int32_t n_head = 0; + int32_t n_layer = 0; + // idefics3 + int32_t image_longest_edge = 0; + int32_t image_min_pixels = -1; + int32_t image_max_pixels = -1; + int32_t n_merge = 0; // number of patch merges **per-side** + + float image_mean[3]; + float image_std[3]; + + // for models using dynamic image size, we need to have a smaller image size to warmup + // otherwise, user will get OOM everytime they load the model + int32_t warmup_image_size = 0; + int32_t warmup_audio_size = 3000; + + ffn_op_type ffn_op = FFN_GELU; + + patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT; + + float eps = 1e-6; + float rope_theta = 0.0; + + std::vector image_res_candidates; // for llava-uhd style models + int32_t image_crop_resolution; + std::unordered_set vision_feature_layer; + int32_t attn_window_size = 0; + int32_t n_wa_pattern = 0; + + // audio + int32_t n_mel_bins = 0; // whisper preprocessor + int32_t proj_stack_factor = 0; // ultravox + + // audio-to-mel preprocessor params + int32_t audio_chunk_len = -1; // in seconds + int32_t audio_sample_rate = -1; + int32_t audio_n_fft = -1; + int32_t audio_window_len = -1; + int32_t audio_hop_len = -1; + + // legacy + bool has_llava_projector = false; + int minicpmv_version = 0; + int32_t minicpmv_query_num = 0; // MiniCPM-V query number + + // custom value provided by user, can be undefined if not set + int32_t custom_image_min_tokens = -1; + int32_t custom_image_max_tokens = -1; + + void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) { + const int cur_merge = n_merge == 0 ? 1 : n_merge; + const int patch_area = patch_size * patch_size * cur_merge * cur_merge; + image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area; + image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area; + warmup_image_size = static_cast(std::sqrt(image_max_pixels)); + } + + void set_warmup_n_tokens(int n_tokens) { + int n_tok_per_side = static_cast(std::sqrt(n_tokens)); + GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n"); + const int cur_merge = n_merge == 0 ? 
1 : n_merge; + warmup_image_size = n_tok_per_side * patch_size * cur_merge; + // TODO: support warmup size for custom token numbers + } +}; + +struct clip_layer { + // attention + ggml_tensor * k_w = nullptr; + ggml_tensor * k_b = nullptr; + ggml_tensor * q_w = nullptr; + ggml_tensor * q_b = nullptr; + ggml_tensor * v_w = nullptr; + ggml_tensor * v_b = nullptr; + ggml_tensor * qkv_w = nullptr; + ggml_tensor * qkv_b = nullptr; + + ggml_tensor * o_w = nullptr; + ggml_tensor * o_b = nullptr; + + ggml_tensor * k_norm = nullptr; + ggml_tensor * q_norm = nullptr; + + // layernorm 1 + ggml_tensor * ln_1_w = nullptr; + ggml_tensor * ln_1_b = nullptr; + + ggml_tensor * ff_up_w = nullptr; + ggml_tensor * ff_up_b = nullptr; + ggml_tensor * ff_gate_w = nullptr; + ggml_tensor * ff_gate_b = nullptr; + ggml_tensor * ff_down_w = nullptr; + ggml_tensor * ff_down_b = nullptr; + + // layernorm 2 + ggml_tensor * ln_2_w = nullptr; + ggml_tensor * ln_2_b = nullptr; + + // layer scale (no bias) + ggml_tensor * ls_1_w = nullptr; + ggml_tensor * ls_2_w = nullptr; + + // qwen3vl deepstack merger + ggml_tensor * deepstack_norm_w = nullptr; + ggml_tensor * deepstack_norm_b = nullptr; + ggml_tensor * deepstack_fc1_w = nullptr; + ggml_tensor * deepstack_fc1_b = nullptr; + ggml_tensor * deepstack_fc2_w = nullptr; + ggml_tensor * deepstack_fc2_b = nullptr; + + bool has_deepstack() const { + return deepstack_fc1_w != nullptr; + } +}; + +struct clip_model { + clip_modality modality = CLIP_MODALITY_VISION; + projector_type proj_type = PROJECTOR_TYPE_MLP; + clip_hparams hparams; + + // embeddings + ggml_tensor * class_embedding = nullptr; + ggml_tensor * patch_embeddings_0 = nullptr; + ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) + ggml_tensor * patch_bias = nullptr; + ggml_tensor * position_embeddings = nullptr; + ggml_tensor * norm_embd_w = nullptr; + ggml_tensor * norm_embd_b = nullptr; + + ggml_tensor * pre_ln_w = nullptr; + ggml_tensor * pre_ln_b = nullptr; + + std::vector layers; + + int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer + + ggml_tensor * post_ln_w; + ggml_tensor * post_ln_b; + + ggml_tensor * projection; // TODO: rename it to fc (fully connected layer) + ggml_tensor * mm_fc_w; + ggml_tensor * mm_fc_b; + ggml_tensor * mm_ffn_up_w = nullptr; + ggml_tensor * mm_ffn_up_b = nullptr; + ggml_tensor * mm_ffn_gate_w = nullptr; + ggml_tensor * mm_ffn_gate_b = nullptr; + ggml_tensor * mm_ffn_down_w = nullptr; + ggml_tensor * mm_ffn_down_b = nullptr; + ggml_tensor * mm_post_norm_w = nullptr; + ggml_tensor * mm_post_norm_b = nullptr; + + // LLaVA projection + ggml_tensor * mm_input_norm_w = nullptr; + ggml_tensor * mm_input_norm_b = nullptr; + ggml_tensor * mm_0_w = nullptr; + ggml_tensor * mm_0_b = nullptr; + ggml_tensor * mm_2_w = nullptr; + ggml_tensor * mm_2_b = nullptr; + + ggml_tensor * image_newline = nullptr; + + // Yi type models with mlp+normalization projection + ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 + ggml_tensor * mm_1_b = nullptr; + ggml_tensor * mm_3_w = nullptr; + ggml_tensor * mm_3_b = nullptr; + ggml_tensor * mm_4_w = nullptr; + ggml_tensor * mm_4_b = nullptr; + + // GLMV-Edge projection + ggml_tensor * mm_model_adapter_conv_w = nullptr; + ggml_tensor * mm_model_adapter_conv_b = nullptr; + + // MobileVLM projection + ggml_tensor * mm_model_mlp_1_w = nullptr; + ggml_tensor * mm_model_mlp_1_b = nullptr; + ggml_tensor * mm_model_mlp_3_w = nullptr; + ggml_tensor * 
mm_model_mlp_3_b = nullptr; + ggml_tensor * mm_model_block_1_block_0_0_w = nullptr; + ggml_tensor * mm_model_block_1_block_0_1_w = nullptr; + ggml_tensor * mm_model_block_1_block_0_1_b = nullptr; + ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr; + ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr; + ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr; + ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr; + ggml_tensor * mm_model_block_1_block_2_0_w = nullptr; + ggml_tensor * mm_model_block_1_block_2_1_w = nullptr; + ggml_tensor * mm_model_block_1_block_2_1_b = nullptr; + ggml_tensor * mm_model_block_2_block_0_0_w = nullptr; + ggml_tensor * mm_model_block_2_block_0_1_w = nullptr; + ggml_tensor * mm_model_block_2_block_0_1_b = nullptr; + ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr; + ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr; + ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr; + ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr; + ggml_tensor * mm_model_block_2_block_2_0_w = nullptr; + ggml_tensor * mm_model_block_2_block_2_1_w = nullptr; + ggml_tensor * mm_model_block_2_block_2_1_b = nullptr; + + // MobileVLM_V2 projection + ggml_tensor * mm_model_mlp_0_w = nullptr; + ggml_tensor * mm_model_mlp_0_b = nullptr; + ggml_tensor * mm_model_mlp_2_w = nullptr; + ggml_tensor * mm_model_mlp_2_b = nullptr; + ggml_tensor * mm_model_peg_0_w = nullptr; + ggml_tensor * mm_model_peg_0_b = nullptr; + + // MINICPMV projection + ggml_tensor * mm_model_pos_embed_k = nullptr; + ggml_tensor * mm_model_query = nullptr; + ggml_tensor * mm_model_proj = nullptr; + ggml_tensor * mm_model_kv_proj = nullptr; + ggml_tensor * mm_model_attn_q_w = nullptr; + ggml_tensor * mm_model_attn_q_b = nullptr; + ggml_tensor * mm_model_attn_k_w = nullptr; + ggml_tensor * mm_model_attn_k_b = nullptr; + ggml_tensor * mm_model_attn_v_w = nullptr; + ggml_tensor * mm_model_attn_v_b = nullptr; + ggml_tensor * mm_model_attn_o_w = nullptr; + ggml_tensor * mm_model_attn_o_b = nullptr; + ggml_tensor * mm_model_ln_q_w = nullptr; + ggml_tensor * mm_model_ln_q_b = nullptr; + ggml_tensor * mm_model_ln_kv_w = nullptr; + ggml_tensor * mm_model_ln_kv_b = nullptr; + ggml_tensor * mm_model_ln_post_w = nullptr; + ggml_tensor * mm_model_ln_post_b = nullptr; + + // gemma3 + ggml_tensor * mm_input_proj_w = nullptr; + ggml_tensor * mm_soft_emb_norm_w = nullptr; + + // pixtral, glm4v + ggml_tensor * token_embd_img_break = nullptr; + ggml_tensor * mm_patch_merger_w = nullptr; + ggml_tensor * mm_patch_merger_b = nullptr; + + // ultravox / whisper encoder + ggml_tensor * conv1d_1_w = nullptr; + ggml_tensor * conv1d_1_b = nullptr; + ggml_tensor * conv1d_2_w = nullptr; + ggml_tensor * conv1d_2_b = nullptr; + ggml_tensor * mm_norm_pre_w = nullptr; + ggml_tensor * mm_norm_pre_b = nullptr; + ggml_tensor * mm_norm_mid_w = nullptr; + + // cogvlm + ggml_tensor * mm_post_fc_norm_w = nullptr; + ggml_tensor * mm_post_fc_norm_b = nullptr; + ggml_tensor * mm_h_to_4h_w = nullptr; + ggml_tensor * mm_gate_w = nullptr; + ggml_tensor * mm_4h_to_h_w = nullptr; + ggml_tensor * mm_boi = nullptr; + ggml_tensor * mm_eoi = nullptr; + + bool audio_has_avgpool() const { + return proj_type == PROJECTOR_TYPE_QWEN2A + || proj_type == PROJECTOR_TYPE_VOXTRAL; + } + + bool audio_has_stack_frames() const { + return proj_type == PROJECTOR_TYPE_ULTRAVOX + || proj_type == PROJECTOR_TYPE_VOXTRAL; + } +}; + +const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx); diff --git a/llama/llama.cpp/tools/mtmd/clip.cpp 
b/llama/llama.cpp/tools/mtmd/clip.cpp index 2a325c72..d3a37842 100644 --- a/llama/llama.cpp/tools/mtmd/clip.cpp +++ b/llama/llama.cpp/tools/mtmd/clip.cpp @@ -1,9 +1,9 @@ -// NOTE: This is modified from clip.cpp only for LLaVA, -// so there might be still unnecessary artifacts hanging around -// I'll gradually clean and extend it -// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch #include "clip.h" #include "clip-impl.h" +#include "clip-model.h" +#include "clip-graph.h" +#include "models/models.h" + #include "ggml.h" #include "ggml-cpp.h" #include "ggml-alloc.h" @@ -39,18 +39,6 @@ struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL}; -enum ffn_op_type { - FFN_GELU, - FFN_GELU_ERF, - FFN_SILU, - FFN_GELU_QUICK, -}; - -enum norm_type { - NORM_TYPE_NORMAL, - NORM_TYPE_RMS, -}; - //#define CLIP_DEBUG_FUNCTIONS #ifdef CLIP_DEBUG_FUNCTIONS @@ -162,267 +150,6 @@ static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u #endif -// -// clip layers -// - -enum patch_merge_type { - PATCH_MERGE_FLAT, - PATCH_MERGE_SPATIAL_UNPAD, -}; - -struct clip_hparams { - int32_t image_size = 0; - int32_t patch_size = 0; - int32_t n_embd = 0; - int32_t n_ff = 0; - int32_t projection_dim = 0; - int32_t n_head = 0; - int32_t n_layer = 0; - // idefics3 - int32_t image_longest_edge = 0; - int32_t image_min_pixels = -1; - int32_t image_max_pixels = -1; - int32_t n_merge = 0; // number of patch merges **per-side** - - float image_mean[3]; - float image_std[3]; - - // for models using dynamic image size, we need to have a smaller image size to warmup - // otherwise, user will get OOM everytime they load the model - int32_t warmup_image_size = 0; - int32_t warmup_audio_size = 3000; - - ffn_op_type ffn_op = FFN_GELU; - - patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT; - - float eps = 1e-6; - float rope_theta = 0.0; - - std::vector image_res_candidates; // for llava-uhd style models - int32_t image_crop_resolution; - std::unordered_set vision_feature_layer; - int32_t attn_window_size = 0; - int32_t n_wa_pattern = 0; - - // audio - int32_t n_mel_bins = 0; // whisper preprocessor - int32_t proj_stack_factor = 0; // ultravox - - // legacy - bool has_llava_projector = false; - int minicpmv_version = 0; - int32_t minicpmv_query_num = 0; // MiniCPM-V query number - - // custom value provided by user, can be undefined if not set - int32_t custom_image_min_tokens = -1; - int32_t custom_image_max_tokens = -1; - - void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) { - const int cur_merge = n_merge == 0 ? 1 : n_merge; - const int patch_area = patch_size * patch_size * cur_merge * cur_merge; - image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area; - image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area; - warmup_image_size = static_cast(std::sqrt(image_max_pixels)); - } - - void set_warmup_n_tokens(int n_tokens) { - int n_tok_per_side = static_cast(std::sqrt(n_tokens)); - GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n"); - const int cur_merge = n_merge == 0 ? 
1 : n_merge; - warmup_image_size = n_tok_per_side * patch_size * cur_merge; - // TODO: support warmup size for custom token numbers - } -}; - -struct clip_layer { - // attention - ggml_tensor * k_w = nullptr; - ggml_tensor * k_b = nullptr; - ggml_tensor * q_w = nullptr; - ggml_tensor * q_b = nullptr; - ggml_tensor * v_w = nullptr; - ggml_tensor * v_b = nullptr; - ggml_tensor * qkv_w = nullptr; - ggml_tensor * qkv_b = nullptr; - - ggml_tensor * o_w = nullptr; - ggml_tensor * o_b = nullptr; - - ggml_tensor * k_norm = nullptr; - ggml_tensor * q_norm = nullptr; - - // layernorm 1 - ggml_tensor * ln_1_w = nullptr; - ggml_tensor * ln_1_b = nullptr; - - ggml_tensor * ff_up_w = nullptr; - ggml_tensor * ff_up_b = nullptr; - ggml_tensor * ff_gate_w = nullptr; - ggml_tensor * ff_gate_b = nullptr; - ggml_tensor * ff_down_w = nullptr; - ggml_tensor * ff_down_b = nullptr; - - // layernorm 2 - ggml_tensor * ln_2_w = nullptr; - ggml_tensor * ln_2_b = nullptr; - - // layer scale (no bias) - ggml_tensor * ls_1_w = nullptr; - ggml_tensor * ls_2_w = nullptr; - - // qwen3vl deepstack merger - ggml_tensor * deepstack_norm_w = nullptr; - ggml_tensor * deepstack_norm_b = nullptr; - ggml_tensor * deepstack_fc1_w = nullptr; - ggml_tensor * deepstack_fc1_b = nullptr; - ggml_tensor * deepstack_fc2_w = nullptr; - ggml_tensor * deepstack_fc2_b = nullptr; - - bool has_deepstack() const { - return deepstack_fc1_w != nullptr; - } -}; - -struct clip_model { - clip_modality modality = CLIP_MODALITY_VISION; - projector_type proj_type = PROJECTOR_TYPE_MLP; - clip_hparams hparams; - - // embeddings - ggml_tensor * class_embedding = nullptr; - ggml_tensor * patch_embeddings_0 = nullptr; - ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) - ggml_tensor * patch_bias = nullptr; - ggml_tensor * position_embeddings = nullptr; - - ggml_tensor * pre_ln_w = nullptr; - ggml_tensor * pre_ln_b = nullptr; - - std::vector layers; - - int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer - - ggml_tensor * post_ln_w; - ggml_tensor * post_ln_b; - - ggml_tensor * projection; // TODO: rename it to fc (fully connected layer) - ggml_tensor * mm_fc_w; - ggml_tensor * mm_fc_b; - - // LLaVA projection - ggml_tensor * mm_input_norm_w = nullptr; - ggml_tensor * mm_input_norm_b = nullptr; - ggml_tensor * mm_0_w = nullptr; - ggml_tensor * mm_0_b = nullptr; - ggml_tensor * mm_2_w = nullptr; - ggml_tensor * mm_2_b = nullptr; - - ggml_tensor * image_newline = nullptr; - - // Yi type models with mlp+normalization projection - ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 - ggml_tensor * mm_1_b = nullptr; - ggml_tensor * mm_3_w = nullptr; - ggml_tensor * mm_3_b = nullptr; - ggml_tensor * mm_4_w = nullptr; - ggml_tensor * mm_4_b = nullptr; - - // GLMV-Edge projection - ggml_tensor * mm_model_adapter_conv_w = nullptr; - ggml_tensor * mm_model_adapter_conv_b = nullptr; - - // MobileVLM projection - ggml_tensor * mm_model_mlp_1_w = nullptr; - ggml_tensor * mm_model_mlp_1_b = nullptr; - ggml_tensor * mm_model_mlp_3_w = nullptr; - ggml_tensor * mm_model_mlp_3_b = nullptr; - ggml_tensor * mm_model_block_1_block_0_0_w = nullptr; - ggml_tensor * mm_model_block_1_block_0_1_w = nullptr; - ggml_tensor * mm_model_block_1_block_0_1_b = nullptr; - ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr; - ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr; - ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr; - ggml_tensor * 
mm_model_block_1_block_1_fc2_b = nullptr; - ggml_tensor * mm_model_block_1_block_2_0_w = nullptr; - ggml_tensor * mm_model_block_1_block_2_1_w = nullptr; - ggml_tensor * mm_model_block_1_block_2_1_b = nullptr; - ggml_tensor * mm_model_block_2_block_0_0_w = nullptr; - ggml_tensor * mm_model_block_2_block_0_1_w = nullptr; - ggml_tensor * mm_model_block_2_block_0_1_b = nullptr; - ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr; - ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr; - ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr; - ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr; - ggml_tensor * mm_model_block_2_block_2_0_w = nullptr; - ggml_tensor * mm_model_block_2_block_2_1_w = nullptr; - ggml_tensor * mm_model_block_2_block_2_1_b = nullptr; - - // MobileVLM_V2 projection - ggml_tensor * mm_model_mlp_0_w = nullptr; - ggml_tensor * mm_model_mlp_0_b = nullptr; - ggml_tensor * mm_model_mlp_2_w = nullptr; - ggml_tensor * mm_model_mlp_2_b = nullptr; - ggml_tensor * mm_model_peg_0_w = nullptr; - ggml_tensor * mm_model_peg_0_b = nullptr; - - // MINICPMV projection - ggml_tensor * mm_model_pos_embed_k = nullptr; - ggml_tensor * mm_model_query = nullptr; - ggml_tensor * mm_model_proj = nullptr; - ggml_tensor * mm_model_kv_proj = nullptr; - ggml_tensor * mm_model_attn_q_w = nullptr; - ggml_tensor * mm_model_attn_q_b = nullptr; - ggml_tensor * mm_model_attn_k_w = nullptr; - ggml_tensor * mm_model_attn_k_b = nullptr; - ggml_tensor * mm_model_attn_v_w = nullptr; - ggml_tensor * mm_model_attn_v_b = nullptr; - ggml_tensor * mm_model_attn_o_w = nullptr; - ggml_tensor * mm_model_attn_o_b = nullptr; - ggml_tensor * mm_model_ln_q_w = nullptr; - ggml_tensor * mm_model_ln_q_b = nullptr; - ggml_tensor * mm_model_ln_kv_w = nullptr; - ggml_tensor * mm_model_ln_kv_b = nullptr; - ggml_tensor * mm_model_ln_post_w = nullptr; - ggml_tensor * mm_model_ln_post_b = nullptr; - - // gemma3 - ggml_tensor * mm_input_proj_w = nullptr; - ggml_tensor * mm_soft_emb_norm_w = nullptr; - - // pixtral - ggml_tensor * token_embd_img_break = nullptr; - ggml_tensor * mm_patch_merger_w = nullptr; - - // ultravox / whisper encoder - ggml_tensor * conv1d_1_w = nullptr; - ggml_tensor * conv1d_1_b = nullptr; - ggml_tensor * conv1d_2_w = nullptr; - ggml_tensor * conv1d_2_b = nullptr; - ggml_tensor * mm_norm_pre_w = nullptr; - ggml_tensor * mm_norm_mid_w = nullptr; - - // cogvlm - ggml_tensor * mm_post_fc_norm_w = nullptr; - ggml_tensor * mm_post_fc_norm_b = nullptr; - ggml_tensor * mm_h_to_4h_w = nullptr; - ggml_tensor * mm_gate_w = nullptr; - ggml_tensor * mm_4h_to_h_w = nullptr; - ggml_tensor * mm_boi = nullptr; - ggml_tensor * mm_eoi = nullptr; - - bool audio_has_avgpool() const { - return proj_type == PROJECTOR_TYPE_QWEN2A - || proj_type == PROJECTOR_TYPE_VOXTRAL; - } - - bool audio_has_stack_frames() const { - return proj_type == PROJECTOR_TYPE_ULTRAVOX - || proj_type == PROJECTOR_TYPE_VOXTRAL; - } -}; - struct clip_ctx { clip_model model; @@ -505,1599 +232,150 @@ struct clip_ctx { } }; -struct clip_graph { - clip_ctx * ctx; - const clip_model & model; - const clip_hparams & hparams; +// +// clip_graph +// - // we only support single image per batch - const clip_image_f32 & img; +clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) : + model(ctx->model), + hparams(model.hparams), + proj_type(ctx->proj_type()), + img(img), + patch_size(hparams.patch_size), + n_patches_x(img.nx / patch_size), + n_patches_y(img.ny / patch_size), + n_patches(n_patches_x * n_patches_y), + n_embd(hparams.n_embd), + 
n_head(hparams.n_head), + d_head(n_embd / n_head), + n_layer(hparams.n_layer), + n_mmproj_embd(clip_n_mmproj_embd(ctx)), + eps(hparams.eps), + kq_scale(1.0f / sqrtf((float)d_head)), + flash_attn_type(ctx->flash_attn_type), + debug_graph(ctx->debug_graph), + debug_print_tensors(ctx->debug_print_tensors) { + struct ggml_init_params params = { + /*.mem_size =*/ ctx->buf_compute_meta.size(), + /*.mem_buffer =*/ ctx->buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + ctx0_ptr.reset(ggml_init(params)); + ctx0 = ctx0_ptr.get(); + gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false); +} - const int patch_size; - const int n_patches_x; - const int n_patches_y; - const int n_patches; - const int n_embd; - const int n_head; - const int d_head; - const int n_layer; - const float eps; - const float kq_scale; - - ggml_context_ptr ctx0_ptr; - ggml_context * ctx0; - ggml_cgraph * gf; - - clip_graph(clip_ctx * ctx, const clip_image_f32 & img) : - ctx(ctx), - model(ctx->model), - hparams(model.hparams), - img(img), - patch_size(hparams.patch_size), - n_patches_x(img.nx / patch_size), - n_patches_y(img.ny / patch_size), - n_patches(n_patches_x * n_patches_y), - n_embd(hparams.n_embd), - n_head(hparams.n_head), - d_head(n_embd / n_head), - n_layer(hparams.n_layer), - eps(hparams.eps), - kq_scale(1.0f / sqrtf((float)d_head)) { - struct ggml_init_params params = { - /*.mem_size =*/ ctx->buf_compute_meta.size(), - /*.mem_buffer =*/ ctx->buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - ctx0_ptr.reset(ggml_init(params)); - ctx0 = ctx0_ptr.get(); - gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false); - } - - ggml_cgraph * build_siglip() { - ggml_tensor * inp = build_inp(); - - ggml_tensor * learned_pos_embd = model.position_embeddings; - if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) { - learned_pos_embd = resize_position_embeddings(); - } - - ggml_tensor * cur = build_vit( - inp, n_patches, - NORM_TYPE_NORMAL, - hparams.ffn_op, - learned_pos_embd, - nullptr); - - if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) { - const int batch_size = 1; - GGML_ASSERT(n_patches_x == n_patches_y); - const int patches_per_image = n_patches_x; - const int kernel_size = hparams.n_merge; - - cur = ggml_transpose(ctx0, cur); - cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); - - // doing a pool2d to reduce the number of output tokens - cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size); - cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); - - // apply norm before projection - cur = ggml_rms_norm(ctx0, cur, eps); - cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w); - - // apply projection - cur = ggml_mul_mat(ctx0, - ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), - cur); - - } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { - // pixel_shuffle - // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 - const int scale_factor = model.hparams.n_merge; - cur = build_patch_merge_permute(cur, scale_factor); - cur = ggml_mul_mat(ctx0, model.projection, cur); - - } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) { - // pixel unshuffle block - const int scale_factor = model.hparams.n_merge; - cur = build_patch_merge_permute(cur, scale_factor); - - // projection - cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm - cur = 
ggml_mul(ctx0, cur, model.mm_input_norm_w); - cur = ggml_add(ctx0, cur, model.mm_input_norm_b); - - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - cur = ggml_add(ctx0, cur, model.mm_2_b); - - } else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) { - cur = build_ffn(cur, - model.mm_0_w, model.mm_0_b, - nullptr, nullptr, - model.mm_1_w, model.mm_1_b, - hparams.ffn_op, - -1); - - } else { - GGML_ABORT("SigLIP: Unsupported projector type"); - } - - // build the graph +void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const { + if (debug_graph) { + ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0)); + std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name; + ggml_set_name(cur, cur_name.c_str()); + ggml_set_output(cur); ggml_build_forward_expand(gf, cur); - - return gf; + debug_print_tensors.push_back(cur); } +} - ggml_cgraph * build_pixtral() { - const int n_merge = hparams.n_merge; +// siglip2 naflex +ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) { + ggml_tensor * pos_embd = model.position_embeddings; + const int height = img.ny / patch_size; + const int width = img.nx / patch_size; + const uint32_t mode = interpolation_mode; + const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); - // 2D input positions - ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); - ggml_set_name(pos_h, "pos_h"); - ggml_set_input(pos_h); - - ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); - ggml_set_name(pos_w, "pos_w"); - ggml_set_input(pos_w); - - auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true); - }; - - ggml_tensor * inp = build_inp(); - ggml_tensor * cur = build_vit( - inp, n_patches, - NORM_TYPE_RMS, - hparams.ffn_op, - nullptr, // no learned pos embd - add_pos); - - // mistral small 3.1 patch merger - // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67 - if (model.mm_patch_merger_w) { - GGML_ASSERT(hparams.n_merge > 0); - - cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w); - - // reshape image tokens to 2D grid - cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y); - cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd] - cur = ggml_cont(ctx0, cur); - - // torch.nn.functional.unfold is just an im2col under the hood - // we just need a dummy kernel to make it work - ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0); - cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type); - - // project to n_embd - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); - cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur); - } - - // LlavaMultiModalProjector (always using GELU activation) - { - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - if (model.mm_1_b) { - cur = ggml_add(ctx0, cur, model.mm_1_b); - } - - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - if (model.mm_2_b) { - cur = ggml_add(ctx0, cur, model.mm_2_b); - } - } - - // arrangement of the [IMG_BREAK] token - if (model.token_embd_img_break) { - // not efficient, but works - // the trick is to view the embeddings as a 3D tensor with shape [n_embd, 
n_patches_per_row, n_rows] - // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension - // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows] - - const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y; - const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x; - const int p_total = p_x * p_y; - const int n_embd_text = cur->ne[0]; - const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row - - ggml_tensor * tmp = ggml_reshape_3d(ctx0, cur, n_embd_text, p_x, p_y); - ggml_tensor * tok = ggml_new_tensor_3d(ctx0, tmp->type, n_embd_text, 1, p_y); - tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor - tok = ggml_add(ctx0, tok, model.token_embd_img_break); - tmp = ggml_concat(ctx0, tmp, tok, 1); - cur = ggml_view_2d(ctx0, tmp, - n_embd_text, n_tokens_output, - ggml_row_size(tmp->type, n_embd_text), 0); - } - - // build the graph - ggml_build_forward_expand(gf, cur); - - return gf; - } - - // Qwen2VL and Qwen2.5VL use M-RoPE - ggml_cgraph * build_qwen2vl() { - GGML_ASSERT(model.patch_bias == nullptr); - GGML_ASSERT(model.class_embedding == nullptr); - - const int batch_size = 1; - const bool use_window_attn = hparams.n_wa_pattern > 0; - const int n_wa_pattern = hparams.n_wa_pattern; - const int n_pos = n_patches; - const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position - - norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL - ? NORM_TYPE_RMS // qwen 2.5 vl - : NORM_TYPE_NORMAL; // qwen 2 vl - - int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; - - ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - - GGML_ASSERT(img.nx % (patch_size * 2) == 0); - GGML_ASSERT(img.ny % (patch_size * 2) == 0); - - // second conv dimension - { - auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_add(ctx0, inp, inp_1); - - inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] - inp = ggml_cont_4d( - ctx0, inp, - n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); - inp = ggml_reshape_4d( - ctx0, inp, - n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); - inp = ggml_cont_3d( - ctx0, inp, - n_embd, n_patches_x * n_patches_y, batch_size); - } - - ggml_tensor * inpL = inp; - ggml_tensor * window_mask = nullptr; - ggml_tensor * window_idx = nullptr; - ggml_tensor * inv_window_idx = nullptr; - - ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); - - // pre-layernorm - if (model.pre_ln_w) { - inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); - } - - if (use_window_attn) { - // handle window attention inputs - inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); - ggml_set_name(inv_window_idx, "inv_window_idx"); - ggml_set_input(inv_window_idx); - // mask for window attention - window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); - ggml_set_name(window_mask, "window_mask"); - ggml_set_input(window_mask); - - // if flash attn is used, we need to pad the mask and cast to f16 - if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { - int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1]; - if (n_pad > 
0) { - window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0); - } - window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); - } - - // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] - GGML_ASSERT(batch_size == 1); - inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); - inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); - inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); - } - - // loop over layers - for (int il = 0; il < n_layer; il++) { - auto & layer = model.layers[il]; - const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true; - - ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states - - // layernorm1 - cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); - cb(cur, "ln1", il); - - // self-attention - { - ggml_tensor * Qcur = ggml_add(ctx0, - ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); - ggml_tensor * Kcur = ggml_add(ctx0, - ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); - ggml_tensor * Vcur = ggml_add(ctx0, - ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); - - Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); - Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); - Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - // apply M-RoPE - Qcur = ggml_rope_multi( - ctx0, Qcur, positions, nullptr, - d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); - Kcur = ggml_rope_multi( - ctx0, Kcur, positions, nullptr, - d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); - - cb(Qcur, "Qcur_rope", il); - cb(Kcur, "Kcur_rope", il); - - ggml_tensor * attn_mask = full_attn ? 
nullptr : window_mask; - - cur = build_attn(layer.o_w, layer.o_b, - Qcur, Kcur, Vcur, attn_mask, kq_scale, il); - cb(cur, "attn_out", il); - } - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, inpL); - - inpL = cur; // inpL = residual, cur = hidden_states - - cb(cur, "ffn_inp", il); - - // layernorm2 - cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); - cb(cur, "ffn_inp_normed", il); - - // ffn - cur = build_ffn(cur, - layer.ff_up_w, layer.ff_up_b, - layer.ff_gate_w, layer.ff_gate_b, - layer.ff_down_w, layer.ff_down_b, - hparams.ffn_op, il); - - cb(cur, "ffn_out", il); - - // residual 2 - cur = ggml_add(ctx0, inpL, cur); - cb(cur, "layer_out", il); - - inpL = cur; - } - - // post-layernorm - if (model.post_ln_w) { - inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); - } - - // multimodal projection - ggml_tensor * embeddings = inpL; - embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); - - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - - // GELU activation - embeddings = ggml_gelu(ctx0, embeddings); - - // Second linear layer - embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); - - if (use_window_attn) { - window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); - ggml_set_name(window_idx, "window_idx"); - ggml_set_input(window_idx); - - // embeddings shape: [n_embd, n_patches_x * n_patches_y, batch_size] - GGML_ASSERT(batch_size == 1); - embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4); - embeddings = ggml_get_rows(ctx0, embeddings, window_idx); - embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size); - } - - // build the graph - ggml_build_forward_expand(gf, embeddings); - - return gf; - } - - // Qwen3VL - ggml_cgraph * build_qwen3vl() { - GGML_ASSERT(model.patch_bias != nullptr); - GGML_ASSERT(model.position_embeddings != nullptr); - GGML_ASSERT(model.class_embedding == nullptr); - - const int batch_size = 1; - const int n_pos = n_patches; - const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position - - norm_type norm_t = NORM_TYPE_NORMAL; - - int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; - - ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - - GGML_ASSERT(img.nx % (patch_size * 2) == 0); - GGML_ASSERT(img.ny % (patch_size * 2) == 0); - - // second conv dimension - { - auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_add(ctx0, inp, inp_1); - - inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] - inp = ggml_cont_4d( - ctx0, inp, - n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); - inp = ggml_reshape_4d( - ctx0, inp, - n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); - inp = ggml_cont_3d( - ctx0, inp, - n_embd, n_patches_x * n_patches_y, batch_size); - } - - // add patch bias - if (model.patch_bias != nullptr) { - inp = ggml_add(ctx0, inp, model.patch_bias); - cb(inp, "patch_bias", -1); - } - - // calculate absolute position embedding and apply - ggml_tensor * learned_pos_embd = resize_position_embeddings(); - learned_pos_embd = 
ggml_cont_4d( - ctx0, learned_pos_embd, - n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); - learned_pos_embd = ggml_reshape_4d( - ctx0, learned_pos_embd, - n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3); - learned_pos_embd = ggml_cont_3d( - ctx0, learned_pos_embd, - n_embd, n_patches_x * n_patches_y, batch_size); - inp = ggml_add(ctx0, inp, learned_pos_embd); - cb(inp, "inp_pos_emb", -1); - - ggml_tensor * inpL = inp; - - ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); - - // pre-layernorm - if (model.pre_ln_w) { - inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); - } - - // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size] - ggml_tensor * deepstack_features = nullptr; - const int merge_factor = hparams.n_merge > 0 ? hparams.n_merge * hparams.n_merge : 4; // default 2x2=4 for qwen3vl - - // loop over layers - for (int il = 0; il < n_layer; il++) { - auto & layer = model.layers[il]; - - ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states - - // layernorm1 - cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); - cb(cur, "ln1", il); - - // self-attention - { - cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); - cur = ggml_add(ctx0, cur, layer.qkv_b); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ 0); - - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ ggml_row_size(cur->type, n_embd)); - - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, - /* nb1 */ ggml_row_size(cur->type, d_head), - /* nb2 */ cur->nb[1], - /* offset */ ggml_row_size(cur->type, 2 * n_embd)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - // apply M-RoPE - Qcur = ggml_rope_multi( - ctx0, Qcur, positions, nullptr, - d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); - Kcur = ggml_rope_multi( - ctx0, Kcur, positions, nullptr, - d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); - - cb(Qcur, "Qcur_rope", il); - cb(Kcur, "Kcur_rope", il); - - cur = build_attn(layer.o_w, layer.o_b, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, inpL); - - inpL = cur; // inpL = residual, cur = hidden_states - - cb(cur, "ffn_inp", il); - - // layernorm2 - cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); - cb(cur, "ffn_inp_normed", il); - - // ffn - cur = build_ffn(cur, - layer.ff_up_w, layer.ff_up_b, - layer.ff_gate_w, layer.ff_gate_b, - layer.ff_down_w, layer.ff_down_b, - hparams.ffn_op, il); - - cb(cur, "ffn_out", il); - - // residual 2 - cur = ggml_add(ctx0, inpL, cur); - cb(cur, "layer_out", il); - - if (layer.has_deepstack()) { - ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size); - feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il); - feat = build_ffn(feat, - layer.deepstack_fc1_w, layer.deepstack_fc1_b, - nullptr, nullptr, - layer.deepstack_fc2_w, layer.deepstack_fc2_b, - ffn_op_type::FFN_GELU, 
il); - - if(!deepstack_features) { - deepstack_features = feat; - } else { - // concat along the feature dimension - deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0); - } - } - - inpL = cur; - } - - // post-layernorm - if (model.post_ln_w) { - inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); - } - - // multimodal projection - ggml_tensor * embeddings = inpL; - embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); - - embeddings = build_ffn(embeddings, - model.mm_0_w, model.mm_0_b, - nullptr, nullptr, - model.mm_1_w, model.mm_1_b, - ffn_op_type::FFN_GELU, -1); - - embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension - - // build the graph - ggml_build_forward_expand(gf, embeddings); - - return gf; - } - - ggml_cgraph * build_minicpmv() { - GGML_ASSERT(model.class_embedding == nullptr); - const int n_pos = n_patches; - const int n_embd_proj = clip_n_mmproj_embd(ctx); - - // position embeddings for the projector (not for ViT) - // see: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/resampler.py#L70 - // base frequency omega - ggml_tensor * omega = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_proj / 4); - ggml_set_name(omega, "omega"); - ggml_set_input(omega); - - // 2D input positions (using float for sinusoidal embeddings) - ggml_tensor * pos_h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos); - ggml_set_name(pos_h, "pos_h"); - ggml_set_input(pos_h); - ggml_tensor * pos_w = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos); - ggml_set_name(pos_w, "pos_w"); - ggml_set_input(pos_w); - - // for selecting learned pos embd, used by ViT - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); - - ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); - - ggml_tensor * inp = build_inp(); - ggml_tensor * embeddings = build_vit( - inp, n_pos, - NORM_TYPE_NORMAL, - hparams.ffn_op, - learned_pos_embd, - nullptr); - - // resampler projector (it is just another transformer) - - ggml_tensor * q = model.mm_model_query; - ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); - - // norm - q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1); - v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1); - - // calculate sinusoidal pos embd - ggml_tensor * pos_embed = nullptr; - { - // outer product - ggml_tensor * omega_b = ggml_repeat_4d(ctx0, omega, omega->ne[0], n_pos, 1, 1); // n_pos rows - ggml_tensor * theta_x = ggml_mul(ctx0, omega_b, pos_w); - ggml_tensor * theta_y = ggml_mul(ctx0, omega_b, pos_h); - // sin and cos - ggml_tensor * pos_embd_x = ggml_concat( - ctx0, - ggml_sin(ctx0, theta_x), - ggml_cos(ctx0, theta_x), - 0 // concat on first dim - ); - ggml_tensor * pos_embd_y = ggml_concat( - ctx0, - ggml_sin(ctx0, theta_y), - ggml_cos(ctx0, theta_y), - 0 // concat on first dim - ); - pos_embed = ggml_concat(ctx0, pos_embd_x, pos_embd_y, 0); - } - - // k = v + pos_embed - ggml_tensor * k = ggml_add(ctx0, v, pos_embed); - - // attention - { - const int d_head = 128; - int n_head = n_embd_proj/d_head; - // Use actual config value if available, otherwise fall back to hardcoded values - int num_query = ctx->model.hparams.minicpmv_query_num; - ggml_tensor * Q = ggml_add(ctx0, - ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), - 
model.mm_model_attn_q_b); - ggml_tensor * K = ggml_add(ctx0, - ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), - model.mm_model_attn_k_b); - ggml_tensor * V = ggml_add(ctx0, - ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), - model.mm_model_attn_v_b); - - Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query); - K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos); - V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos); - - cb(Q, "resampler_Q", -1); - cb(K, "resampler_K", -1); - cb(V, "resampler_V", -1); - - float resampler_kq_scale = 1.0f/ sqrtf(float(d_head)); - embeddings = build_attn( - model.mm_model_attn_o_w, - model.mm_model_attn_o_b, - Q, K, V, nullptr, resampler_kq_scale, -1); - cb(embeddings, "resampler_attn_out", -1); - } - // layernorm - embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1); - - // projection - embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); - - // build the graph - ggml_build_forward_expand(gf, embeddings); - - return gf; - } - - ggml_cgraph * build_internvl() { - GGML_ASSERT(model.class_embedding != nullptr); - GGML_ASSERT(model.position_embeddings != nullptr); - - const int n_pos = n_patches + 1; - ggml_tensor * inp = build_inp(); - - // add CLS token - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); - - // The larger models use a different ViT, which uses RMS norm instead of layer norm - // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 - norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) - ? NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B) - : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models) - - ggml_tensor * cur = build_vit( - inp, n_pos, - norm_t, - hparams.ffn_op, - model.position_embeddings, - nullptr); - - // remove CLS token - cur = ggml_view_2d(ctx0, cur, - n_embd, n_patches, - ggml_row_size(cur->type, n_embd), 0); - - // pixel shuffle - { - const int scale_factor = model.hparams.n_merge; - const int bsz = 1; // batch size, always 1 for now since we don't support batching - const int height = n_patches_y; - const int width = n_patches_x; - GGML_ASSERT(scale_factor > 0); - cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_cont_4d(ctx0, cur, - n_embd * scale_factor * scale_factor, - height / scale_factor, - width / scale_factor, - bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - // flatten to 2D - cur = ggml_cont_2d(ctx0, cur, - n_embd * scale_factor * scale_factor, - cur->ne[1] * cur->ne[2]); - } - - // projector (always using GELU activation) - { - // projector LayerNorm uses pytorch's default eps = 1e-5 - // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 - cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); - cur = ggml_add(ctx0, cur, model.mm_3_b); - } - - // build the graph - ggml_build_forward_expand(gf, cur); - - return gf; - } - - ggml_cgraph * build_llama4() { - GGML_ASSERT(model.class_embedding != nullptr); - GGML_ASSERT(model.position_embeddings != nullptr); - - const int n_pos = n_patches + 1; // +1 for [CLS] - - // 2D input positions - ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, 
GGML_TYPE_I32, n_pos); - ggml_set_name(pos_h, "pos_h"); - ggml_set_input(pos_h); - - ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); - ggml_set_name(pos_w, "pos_w"); - ggml_set_input(pos_w); - - ggml_tensor * inp = build_inp_raw(); - - // Llama4UnfoldConvolution - { - ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0, - patch_size, patch_size, 3, n_embd); - inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type); - inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); - inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); - cb(inp, "patch_conv", -1); - } - - // add CLS token - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); - - // build ViT with 2D position embeddings - auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - // first half is X axis and second half is Y axis - // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312 - // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441 - return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); - }; - ggml_tensor * cur = build_vit( - inp, n_pos, - NORM_TYPE_NORMAL, - hparams.ffn_op, - model.position_embeddings, - add_pos); - - // remove CLS token - cur = ggml_view_2d(ctx0, cur, - n_embd, n_patches, - ggml_row_size(cur->type, n_embd), 0); - - // pixel shuffle - // based on Llama4VisionPixelShuffleMLP - // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151 - { - const int scale_factor = model.hparams.n_merge; - const int bsz = 1; // batch size, always 1 for now since we don't support batching - GGML_ASSERT(scale_factor > 0); - GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images - cur = ggml_reshape_4d(ctx0, cur, - n_embd * scale_factor, - n_patches_x / scale_factor, - n_patches_y, - bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_cont_4d(ctx0, cur, - n_embd * scale_factor * scale_factor, - n_patches_x / scale_factor, - n_patches_y / scale_factor, - bsz); - //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - // flatten to 2D - cur = ggml_cont_2d(ctx0, cur, - n_embd * scale_factor * scale_factor, - n_patches / scale_factor / scale_factor); - cb(cur, "pixel_shuffle", -1); - } - - // based on Llama4VisionMLP2 (always uses GELU activation, no bias) - { - cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); - cur = ggml_gelu(ctx0, cur); - cb(cur, "adapter_mlp", -1); - } - - // Llama4MultiModalProjector - cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); - cb(cur, "projected", -1); - - // build the graph - ggml_build_forward_expand(gf, cur); - - return gf; - } - - ggml_cgraph * build_kimivl() { - // 2D input positions - ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); - ggml_set_name(pos_h, "pos_h"); - ggml_set_input(pos_h); - - ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); - ggml_set_name(pos_w, "pos_w"); - ggml_set_input(pos_w); - - ggml_tensor * learned_pos_embd = resize_position_embeddings(); - - // build ViT with 2D position embeddings - auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - // first half is X axis and second half is Y axis - return build_rope_2d(ctx0, cur, pos_w, pos_h, 
hparams.rope_theta, false); - }; - - ggml_tensor * inp = build_inp(); - ggml_tensor * cur = build_vit( - inp, n_patches, - NORM_TYPE_NORMAL, - hparams.ffn_op, - learned_pos_embd, - add_pos); - - cb(cur, "vit_out", -1); - - { - // patch_merger - const int scale_factor = model.hparams.n_merge; - cur = build_patch_merge_permute(cur, scale_factor); - - // projection norm - int proj_inp_dim = cur->ne[0]; - cur = ggml_view_2d(ctx0, cur, - n_embd, cur->ne[1] * scale_factor * scale_factor, - ggml_row_size(cur->type, n_embd), 0); - cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm - cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); - cur = ggml_add(ctx0, cur, model.mm_input_norm_b); - cur = ggml_view_2d(ctx0, cur, - proj_inp_dim, cur->ne[1] / scale_factor / scale_factor, - ggml_row_size(cur->type, proj_inp_dim), 0); - cb(cur, "proj_inp_normed", -1); - - // projection mlp - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - cur = ggml_add(ctx0, cur, model.mm_2_b); - cb(cur, "proj_out", -1); - } - - // build the graph - ggml_build_forward_expand(gf, cur); - - return gf; - } - - // this graph is used by llava, granite and glm - // due to having embedding_stack (used by granite), we cannot reuse build_vit - ggml_cgraph * build_llava() { - const int batch_size = 1; - const int n_pos = n_patches + (model.class_embedding ? 1 : 0); - - GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported"); - - // Calculate the deepest feature layer based on hparams and projector type - int max_feature_layer = n_layer; - { - // Get the index of the second to last layer; this is the default for models that have a llava projector - int il_last = hparams.n_layer - 1; - int deepest_feature_layer = -1; - - if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { - il_last += 1; - } - - // If we set explicit vision feature layers, only go up to the deepest one - // NOTE: only used by granite-vision models for now - for (const auto & feature_layer : hparams.vision_feature_layer) { - if (feature_layer > deepest_feature_layer) { - deepest_feature_layer = feature_layer; - } - } - max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer; - } - - ggml_tensor * inp = build_inp(); - - // concat class_embeddings and patch_embeddings - if (model.class_embedding) { - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); - } - - ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); - - inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions)); - - ggml_tensor * inpL = inp; - - // pre-layernorm - if (model.pre_ln_w) { - inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1); - cb(inpL, "pre_ln", -1); - } - - std::vector embedding_stack; - const auto & vision_feature_layer = hparams.vision_feature_layer; - - // loop over layers - for (int il = 0; il < max_feature_layer; il++) { - auto & layer = model.layers[il]; - ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states - - // If this is an embedding feature layer, save the output. - // NOTE: 0 index here refers to the input to the encoder. 
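// [illustrative aside, not part of the patch] The block above saves intermediate
// layer outputs for granite-vision: indices listed in hparams.vision_feature_layer
// (0 meaning the encoder input) are pushed onto embedding_stack and later
// concatenated along the feature dimension. A minimal standalone C++ sketch of that
// stacking, with plain std::vector standing in for ggml tensors and a dummy layer
// transform; nothing here is taken from the patch beyond the logic it mirrors:
#include <cstdio>
#include <set>
#include <vector>

int main() {
    const int n_layer = 4, n_tokens = 3, n_embd = 2;
    const std::set<int> feature_layers = {1, 3};        // plays the role of vision_feature_layer
    std::vector<float> cur(n_tokens * n_embd, 1.0f);    // encoder input (feature index 0)
    std::vector<std::vector<float>> embedding_stack;

    for (int il = 0; il < n_layer; il++) {
        if (feature_layers.count(il)) {
            embedding_stack.push_back(cur);             // save the *input* of layer il
        }
        for (float & x : cur) x += 1.0f;                // dummy layer transform
    }
    if (feature_layers.count(n_layer)) {
        embedding_stack.push_back(cur);                 // the final output may also be selected
    }

    // concatenate along the feature dimension: each token ends up with
    // n_embd * embedding_stack.size() features
    std::vector<float> out;
    for (int t = 0; t < n_tokens; t++) {
        for (const auto & e : embedding_stack) {
            out.insert(out.end(), e.begin() + t * n_embd, e.begin() + (t + 1) * n_embd);
        }
    }
    printf("features per token: %zu\n", out.size() / n_tokens);
    return 0;
}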
- if (vision_feature_layer.find(il) != vision_feature_layer.end()) { - embedding_stack.push_back(cur); - } - - // layernorm1 - cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); - cb(cur, "layer_inp_normed", il); - - // self-attention - { - ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); - if (layer.q_b) { - Qcur = ggml_add(ctx0, Qcur, layer.q_b); - } - - ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); - if (layer.k_b) { - Kcur = ggml_add(ctx0, Kcur, layer.k_b); - } - - ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); - if (layer.v_b) { - Vcur = ggml_add(ctx0, Vcur, layer.v_b); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); - Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); - Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(layer.o_w, layer.o_b, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, inpL); - - inpL = cur; // inpL = residual, cur = hidden_states - - cb(cur, "ffn_inp", il); - - // layernorm2 - cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); - cb(cur, "ffn_inp_normed", il); - - // ffn - cur = build_ffn(cur, - layer.ff_up_w, layer.ff_up_b, - layer.ff_gate_w, layer.ff_gate_b, - layer.ff_down_w, layer.ff_down_b, - hparams.ffn_op, il); - - cb(cur, "ffn_out", il); - - // residual 2 - cur = ggml_add(ctx0, inpL, cur); - cb(cur, "layer_out", il); - - inpL = cur; - } - - // post-layernorm - if (model.post_ln_w) { - inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1); - } - - ggml_tensor * embeddings = inpL; - - // process vision feature layers (used by granite) - { - // final layer is a vision feature layer - if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) { - embedding_stack.push_back(inpL); - } - - // If feature layers are explicitly set, stack them (if we have multiple) - if (!embedding_stack.empty()) { - embeddings = embedding_stack[0]; - for (size_t i = 1; i < embedding_stack.size(); i++) { - embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); - } - } - } - - // llava projector (also used by granite) - if (ctx->model.hparams.has_llava_projector) { - embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); - - ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); - ggml_set_name(patches, "patches"); - ggml_set_input(patches); - - // shape [1, 576, 1024] - // ne is whcn, ne = [1024, 576, 1, 1] - embeddings = ggml_get_rows(ctx0, embeddings, patches); - - // print_tensor_info(embeddings, "embeddings"); - - // llava projector - if (ctx->proj_type() == PROJECTOR_TYPE_MLP) { - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - - embeddings = ggml_gelu(ctx0, embeddings); - if (model.mm_2_w) { - embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - } - } - else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) { - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); - // First LayerNorm - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w), 
- model.mm_1_b); - - // GELU activation - embeddings = ggml_gelu(ctx0, embeddings); - - // Second linear layer - embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_3_b); - - // Second LayerNorm - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w), - model.mm_4_b); - } - else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) { - // MobileVLM projector - int n_patch = 24; - ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings); - mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b); - mlp_1 = ggml_gelu(ctx0, mlp_1); - ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1); - mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b); - // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1] - - // block 1 - ggml_tensor * block_1 = nullptr; - { - // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24] - mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3); - mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); - // stride = 1, padding = 1, bias is nullptr - block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); - - // layer norm - // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3)); - // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); - - // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] - // hardswish - ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1); - - block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0); - // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] - // pointwise conv - block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b); - block_1 = ggml_relu(ctx0, block_1); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b); - block_1 = ggml_hardsigmoid(ctx0, block_1); - // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1] - block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]); - block_1 = ggml_mul(ctx0, block_1_hw, block_1); - - int w = block_1->ne[0], h = block_1->ne[1]; - block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); - - // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1); - block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); - - // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); - // 
block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] - // residual - block_1 = ggml_add(ctx0, mlp_3, block_1); - } - - // block_2 - { - // stride = 2 - block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1); - - // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] - // layer norm - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3)); - // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); - // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] - // hardswish - ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1); - - // not sure the parameters is right for globalAvgPooling - block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0); - // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] - // pointwise conv - block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b); - block_1 = ggml_relu(ctx0, block_1); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b); - block_1 = ggml_hardsigmoid(ctx0, block_1); - - // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] - block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]); - block_1 = ggml_mul(ctx0, block_1_hw, block_1); - - int w = block_1->ne[0], h = block_1->ne[1]; - block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); - // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1); - block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); - - - // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b); - block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]); - // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1] - } - embeddings = block_1; - } - else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2) - { - int n_patch = 24; - ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); - mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); - mlp_0 = ggml_gelu(ctx0, mlp_0); - ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); - mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); - // mlp_2 ne = [2048, 576, 1, 1] - // // AVG Pool Layer 2*2, strides = 2 - mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3); - // mlp_2 ne = [576, 2048, 1, 1] - mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); - // mlp_2 ne [24, 24, 2048, 1] - mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); - // weight ne = [3, 3, 2048, 1] - ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, 
model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); - peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); - peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); - mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); - peg_0 = ggml_add(ctx0, peg_0, mlp_2); - peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); - embeddings = peg_0; - } - else { - GGML_ABORT("fatal error"); - } - } - - // glm projector - else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { - size_t gridsz = (size_t)sqrt(embeddings->ne[1]); - embeddings = ggml_permute(ctx0,embeddings,1,0,2,3); - embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); - embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); - embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); - embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); - // GLU - { - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); - embeddings = ggml_gelu_inplace(ctx0, embeddings); - ggml_tensor * x = embeddings; - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); - x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); - embeddings = ggml_swiglu_split(ctx0, embeddings, x); - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); - } - // arrangement of BOI/EOI token embeddings - // note: these embeddings are not present in text model, hence we cannot process them as text tokens - // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53 - { - embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI - embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI - } - } - - else { - GGML_ABORT("llava: unknown projector type"); - } - - // build the graph - ggml_build_forward_expand(gf, embeddings); - - return gf; - } - // whisper encoder with custom projector - ggml_cgraph * build_whisper_enc() { - const int n_frames = img.nx; - const int n_pos = n_frames / 2; - GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos); - - ggml_tensor * inp = build_inp_raw(1); - - // conv1d block - { - // convolution + gelu - ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1); - cur = ggml_add(ctx0, cur, model.conv1d_1_b); - - cur = ggml_gelu_erf(ctx0, cur); - - cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1); - cur = ggml_add(ctx0, cur, model.conv1d_2_b); - - cur = ggml_gelu_erf(ctx0, cur); - // transpose - inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); - cb(inp, "after_conv1d", -1); - } - - // sanity check (only check one layer, but it should be the same for all) - GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b); - GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b); - GGML_ASSERT(model.layers[0].q_b); - GGML_ASSERT(model.layers[0].v_b); - GGML_ASSERT(!model.layers[0].k_b); // no bias for k - GGML_ASSERT(model.post_ln_w && model.post_ln_b); - - ggml_tensor * pos_embd_selected = ggml_view_2d( - ctx0, model.position_embeddings, - model.position_embeddings->ne[0], n_pos, - model.position_embeddings->nb[1], 0 - ); - ggml_tensor * cur = build_vit( - inp, n_pos, - NORM_TYPE_NORMAL, - hparams.ffn_op, - pos_embd_selected, - nullptr); - 
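// [illustrative aside, not part of the patch] Just above, the whisper-style audio
// front end runs two half-padded 1D convolutions, the second with stride 2, so
// n_frames mel frames become n_pos = n_frames / 2 positions, and the learned
// position embeddings are then viewed down to their first n_pos rows. A small
// standalone sketch of that length arithmetic; kernel size 3 and n_frames = 3000
// are assumptions for illustration, not values taken from this diff:
#include <cassert>
#include <cstdio>

static int conv_out_len(int n, int kernel, int stride, int pad) {
    // standard 1D convolution output length
    return (n + 2 * pad - kernel) / stride + 1;
}

int main() {
    const int n_frames = 3000;                  // e.g. 30 s of 10 ms mel frames (example value)
    int n = conv_out_len(n_frames, 3, 1, 1);    // conv1d_1: stride 1, half padding keeps the length
    n     = conv_out_len(n,        3, 2, 1);    // conv1d_2: stride 2 halves it
    assert(n == n_frames / 2);                  // matches n_pos = n_frames / 2 in build_whisper_enc

    const int n_pos_embd_rows = 1500;           // example row count of model.position_embeddings
    assert(n_pos_embd_rows >= n);               // mirrors GGML_ASSERT(position_embeddings->ne[1] >= n_pos)
    printf("n_frames=%d -> n_pos=%d\n", n_frames, n);
    return 0;
}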
- cb(cur, "after_transformer", -1); - - if (model.audio_has_stack_frames()) { - // StackAudioFrames - // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py - int64_t stride = n_embd * hparams.proj_stack_factor; - int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); - int64_t pad = padded_len - ggml_nelements(cur); - if (pad > 0) { - cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); - cb(cur, "after_stacked", -1); - } - - if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { - // UltravoxProjector - // pre-norm - cur = ggml_rms_norm(ctx0, cur, 1e-6); - cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); - - // ffn in - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - - // swiglu - // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half - cur = ggml_swiglu_swapped(ctx0, cur); - - // mid-norm - cur = ggml_rms_norm(ctx0, cur, 1e-6); - cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w); - - // ffn out - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - - } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { - // projector - cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); - cur = ggml_add(ctx0, cur, model.mm_fc_b); - - } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { - // projector - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_gelu_erf(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - - } else { - GGML_ABORT("%s: unknown projector type", __func__); - } - - cb(cur, "projected", -1); - - ggml_build_forward_expand(gf, cur); - - return gf; - } - - // cogvlm vision encoder - ggml_cgraph * build_cogvlm() { - GGML_ASSERT(model.class_embedding != nullptr); - GGML_ASSERT(model.position_embeddings != nullptr); - - const int n_pos = n_patches + 1; // +1 for [CLS] - - // build input and concatenate class embedding - ggml_tensor * inp = build_inp(); - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); - - inp = ggml_add(ctx0, inp, model.position_embeddings); - cb(inp, "inp_pos", -1); - - ggml_tensor * inpL = inp; - - for (int il = 0; il < n_layer; il++) { - auto & layer = model.layers[il]; - ggml_tensor * cur = inpL; - - cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); - - cur = ggml_add(ctx0, cur, layer.qkv_b); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 0); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], n_embd * sizeof(float)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), - cur->nb[1], 2 * n_embd * sizeof(float)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(layer.o_w, layer.o_b, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - - cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); - cb(cur, "attn_post_norm", il); - - cur = ggml_add(ctx0, cur, inpL); - inpL = cur; - - cur = build_ffn(cur, - layer.ff_up_w, layer.ff_up_b, - layer.ff_gate_w, layer.ff_gate_b, - layer.ff_down_w, layer.ff_down_b, - hparams.ffn_op, il); - - cb(cur, "ffn_out", il); - - cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); - cb(cur, "ffn_post_norm", il); - - cur = ggml_add(ctx0, cur, inpL); - cb(cur, "layer_out", il); - inpL = cur; - - } - - // remove CLS token (like build_llama4 does) - 
ggml_tensor * cur = ggml_view_2d(ctx0, inpL, - n_embd, n_patches, - ggml_row_size(inpL->type, n_embd), 0); - - // Multiply with mm_model_proj - cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); - - // Apply layernorm, weight, bias - cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); - - // Apply GELU - cur = ggml_gelu_inplace(ctx0, cur); - - // Branch 1: multiply with mm_h_to_4h_w - ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur); - - // Branch 2: multiply with mm_gate_w - ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur); - - // Apply silu - gate = ggml_swiglu_split(ctx0, gate, h_to_4h); - - // Apply mm_4h_to_h_w - cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate); - - // Concatenate with boi and eoi - cur = ggml_concat(ctx0, model.mm_boi, cur, 1); - cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); - - // build the graph - ggml_build_forward_expand(gf, cur); - - return gf; - } - -private: - // - // utility functions - // - - void cb(ggml_tensor * cur0, const char * name, int il) const { - if (ctx->debug_graph) { - ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0)); - std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name; - ggml_set_name(cur, cur_name.c_str()); - ggml_set_output(cur); - ggml_build_forward_expand(gf, cur); - ctx->debug_print_tensors.push_back(cur); - } - } - - // siglip2 naflex - ggml_tensor * resize_position_embeddings() { - ggml_tensor * pos_embd = model.position_embeddings; - const int height = img.ny / patch_size; - const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; - const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); - - GGML_ASSERT(pos_embd); - - if (height == n_per_side && width == n_per_side) { - return pos_embd; - } - - pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side) - pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd) - pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd) - pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height) - pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height) + GGML_ASSERT(pos_embd); + if (height == n_per_side && width == n_per_side) { return pos_embd; } - // build vision transformer (ViT) cgraph - // this function should cover most of the models - // if your model has specific features, you should probably duplicate this function - ggml_tensor * build_vit( - ggml_tensor * inp, - int64_t n_pos, - norm_type norm_t, - ffn_op_type ffn_t, - ggml_tensor * learned_pos_embd, - std::function add_pos - ) { - if (learned_pos_embd) { - inp = ggml_add(ctx0, inp, learned_pos_embd); - cb(inp, "pos_embed", -1); - } + pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side) + pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd) + pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd) + pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height) + pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height) - ggml_tensor * inpL = inp; + return pos_embd; +} - // pre-layernorm - if (model.pre_ln_w) { - inpL = build_norm(inpL, 
model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); - cb(inpL, "pre_ln", -1); - } +// build vision transformer (ViT) cgraph +// this function should cover most of the models +// if your model has specific features, you should probably duplicate this function +ggml_tensor * clip_graph::build_vit( + ggml_tensor * inp, + int64_t n_pos, + norm_type norm_t, + ffn_op_type ffn_t, + ggml_tensor * learned_pos_embd, + std::function add_pos + ) { + if (learned_pos_embd) { + inp = ggml_add(ctx0, inp, learned_pos_embd); + cb(inp, "pos_embed", -1); + } - // loop over layers - for (int il = 0; il < n_layer; il++) { - auto & layer = model.layers[il]; - ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + ggml_tensor * inpL = inp; - // layernorm1 - cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); - cb(cur, "layer_inp_normed", il); + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + cb(inpL, "pre_ln", -1); + } - // self-attention - { - ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + // loop over layers + for (int il = 0; il < n_layer; il++) { + auto & layer = model.layers[il]; + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + cb(cur, "layer_inp_normed", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + if (layer.qkv_w != nullptr) { + // fused qkv + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + if (layer.qkv_b != nullptr) { + cur = ggml_add(ctx0, cur, layer.qkv_b); + } + + Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ 0); + + Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, n_embd)); + + Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); + + // TODO: q/k norm requires row size == n_embd, while here it's d_head + // we can add support in the future if needed + GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr); + + } else { + // separate q, k, v + Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); if (layer.q_b) { Qcur = ggml_add(ctx0, Qcur, layer.q_b); } - ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); if (layer.k_b) { Kcur = ggml_add(ctx0, Kcur, layer.k_b); } - ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); if (layer.v_b) { Vcur = ggml_add(ctx0, Vcur, layer.v_b); } @@ -2115,445 +393,478 @@ private: Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - if (add_pos) { - Qcur = add_pos(Qcur, layer); - Kcur = add_pos(Kcur, layer); - cb(Qcur, "Qcur_pos", il); - cb(Kcur, "Kcur_pos", il); - } - - cur = build_attn(layer.o_w, layer.o_b, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); - cb(cur, "attn_out", il); } - if (layer.ls_1_w) { - cur = ggml_mul(ctx0, cur, layer.ls_1_w); - cb(cur, "attn_out_scaled", il); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if 
(add_pos) { + Qcur = add_pos(Qcur, layer); + Kcur = add_pos(Kcur, layer); + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); } - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, inpL); - - inpL = cur; // inpL = residual, cur = hidden_states - - cb(cur, "ffn_inp", il); - - // layernorm2 - cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); - cb(cur, "ffn_inp_normed", il); - - // ffn - cur = build_ffn(cur, - layer.ff_up_w, layer.ff_up_b, - layer.ff_gate_w, layer.ff_gate_b, - layer.ff_down_w, layer.ff_down_b, - ffn_t, il); - - cb(cur, "ffn_out", il); - - if (layer.ls_2_w) { - cur = ggml_mul(ctx0, cur, layer.ls_2_w); - cb(cur, "ffn_out_scaled", il); - } - - // residual 2 - cur = ggml_add(ctx0, inpL, cur); - cb(cur, "layer_out", il); - - inpL = cur; + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); } - if (ctx->model.audio_has_avgpool()) { - ggml_tensor * cur = inpL; - cur = ggml_transpose(ctx0, cur); - cur = ggml_cont(ctx0, cur); - cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0); - cur = ggml_transpose(ctx0, cur); - cur = ggml_cont(ctx0, cur); - inpL = cur; + if (layer.ls_1_w) { + cur = ggml_mul(ctx0, cur, layer.ls_1_w); + cb(cur, "attn_out_scaled", il); } - // post-layernorm - if (model.post_ln_w) { - inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1); + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + cb(cur, "ffn_inp", il); + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + cb(cur, "ffn_inp_normed", il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + ffn_t, il); + + cb(cur, "ffn_out", il); + + if (layer.ls_2_w) { + cur = ggml_mul(ctx0, cur, layer.ls_2_w); + cb(cur, "ffn_out_scaled", il); } - return inpL; + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + inpL = cur; } - // build the input after conv2d (inp_raw --> patches) - // returns tensor with shape [n_embd, n_patches] - ggml_tensor * build_inp() { - ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - cb(inp, "patch_bias", -1); - } - return inp; + if (model.audio_has_avgpool()) { + ggml_tensor * cur = inpL; + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont(ctx0, cur); + cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0); + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont(ctx0, cur); + inpL = cur; } - ggml_tensor * build_inp_raw(int channels = 3) { - ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels); - ggml_set_name(inp_raw, "inp_raw"); - ggml_set_input(inp_raw); - return inp_raw; + // post-layernorm + if (model.post_ln_w) { + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1); } + return inpL; +} - ggml_tensor * build_norm( - ggml_tensor * cur, - ggml_tensor * mw, - ggml_tensor * mb, - norm_type type, - float norm_eps, - int il) const { - - cur = type == NORM_TYPE_RMS - ? 
ggml_rms_norm(ctx0, cur, norm_eps) - : ggml_norm(ctx0, cur, norm_eps); - - if (mw || mb) { - cb(cur, "norm", il); - } - - if (mw) { - cur = ggml_mul(ctx0, cur, mw); - if (mb) { - cb(cur, "norm_w", il); - } - } - - if (mb) { - cur = ggml_add(ctx0, cur, mb); - } - - return cur; +// build the input after conv2d (inp_raw --> patches) +// returns tensor with shape [n_embd, n_patches] +ggml_tensor * clip_graph::build_inp() { + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); } + return inp; +} - ggml_tensor * build_ffn( - ggml_tensor * cur, - ggml_tensor * up, - ggml_tensor * up_b, - ggml_tensor * gate, - ggml_tensor * gate_b, - ggml_tensor * down, - ggml_tensor * down_b, - ffn_op_type type_op, - int il) const { +ggml_tensor * clip_graph::build_inp_raw(int channels) { + ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + return inp_raw; +} - ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur; - cb(tmp, "ffn_up", il); - - if (up_b) { - tmp = ggml_add(ctx0, tmp, up_b); - cb(tmp, "ffn_up_b", il); - } - - if (gate) { - cur = ggml_mul_mat(ctx0, gate, cur); - cb(cur, "ffn_gate", il); - - if (gate_b) { - cur = ggml_add(ctx0, cur, gate_b); - cb(cur, "ffn_gate_b", il); - } - } else { - cur = tmp; - } - - // we only support parallel ffn for now - switch (type_op) { - case FFN_SILU: - if (gate) { - cur = ggml_swiglu_split(ctx0, cur, tmp); - cb(cur, "ffn_swiglu", il); - } else { - cur = ggml_silu(ctx0, cur); - cb(cur, "ffn_silu", il); - } break; - case FFN_GELU: - if (gate) { - cur = ggml_geglu_split(ctx0, cur, tmp); - cb(cur, "ffn_geglu", il); - } else { - cur = ggml_gelu(ctx0, cur); - cb(cur, "ffn_gelu", il); - } break; - case FFN_GELU_ERF: - if (gate) { - cur = ggml_geglu_erf_split(ctx0, cur, tmp); - cb(cur, "ffn_geglu_erf", il); - } else { - cur = ggml_gelu_erf(ctx0, cur); - cb(cur, "ffn_gelu_erf", il); - } break; - case FFN_GELU_QUICK: - if (gate) { - cur = ggml_geglu_quick_split(ctx0, cur, tmp); - cb(cur, "ffn_geglu_quick", il); - } else { - cur = ggml_gelu_quick(ctx0, cur); - cb(cur, "ffn_gelu_quick", il); - } break; - } - - if (down) { - cur = ggml_mul_mat(ctx0, down, cur); - } - - if (down_b) { - cb(cur, "ffn_down", il); - } - - if (down_b) { - cur = ggml_add(ctx0, cur, down_b); - } - - return cur; - } - - ggml_tensor * build_attn( - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_mask, - float kq_scale, - int il) const { - // these nodes are added to the graph together so that they are not reordered - // by doing so, the number of splits in the graph is reduced - ggml_build_forward_expand(gf, q_cur); - ggml_build_forward_expand(gf, k_cur); - ggml_build_forward_expand(gf, v_cur); - - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); - //cb(k, "k", il); - - ggml_tensor * cur; - - if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); - - k = ggml_cast(ctx0, k, GGML_TYPE_F16); - v = ggml_cast(ctx0, v, GGML_TYPE_F16); - - cur = ggml_flash_attn_ext(ctx0, 
q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); - - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); - - } else { - ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); - v = ggml_cont(ctx0, v); - - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - // F32 may not needed for vision encoders? - // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - - kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); - - ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); - } - - cb(cur, "kqv_out", il); - - if (wo) { - cur = ggml_mul_mat(ctx0, wo, cur); - } - - if (wo_b) { - cur = ggml_add(ctx0, cur, wo_b); - } - - return cur; - } - - // implementation of the 2D RoPE without adding a new op in ggml - // this is not efficient (use double the memory), but works on all backends - // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065 - static ggml_tensor * build_rope_2d( - ggml_context * ctx0, +ggml_tensor * clip_graph::build_norm( ggml_tensor * cur, - ggml_tensor * pos_a, // first half - ggml_tensor * pos_b, // second half - const float freq_base, - const bool interleave_freq - ) { - const int64_t n_dim = cur->ne[0]; - const int64_t n_head = cur->ne[1]; - const int64_t n_pos = cur->ne[2]; + ggml_tensor * mw, + ggml_tensor * mb, + norm_type type, + float norm_eps, + int il) const { - // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos) - // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3 - // first half of cur will use 1e-0, 1e-2 (even) - // second half of cur will use 1e-1, 1e-3 (odd) - // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even - // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2) - // then for the second half, we use freq_scale to shift the inv_freq - // ^ why? replace (2i) with (2i+1) in the above equation - const float freq_scale_odd = interleave_freq - ? std::pow(freq_base, (float)-2/n_dim) - : 1.0; + cur = type == NORM_TYPE_RMS + ? ggml_rms_norm(ctx0, cur, norm_eps) + : ggml_norm(ctx0, cur, norm_eps); - // first half - ggml_tensor * first; - { - first = ggml_view_3d(ctx0, cur, - n_dim/2, n_head, n_pos, - ggml_row_size(cur->type, n_dim), - ggml_row_size(cur->type, n_dim*n_head), - 0); - first = ggml_rope_ext( - ctx0, - first, - pos_a, // positions - nullptr, // freq factors - n_dim/2, // n_dims - 0, 0, freq_base, - 1.0f, 0.0f, 1.0f, 0.0f, 0.0f - ); + if (mw) { + cur = ggml_mul(ctx0, cur, mw); + cb(cur, "norm_w", il); + } + + if (mb) { + cur = ggml_add(ctx0, cur, mb); + cb(cur, "norm_b", il); + } + + return cur; +} + +ggml_tensor * clip_graph::build_ffn( + ggml_tensor * cur, + ggml_tensor * up, + ggml_tensor * up_b, + ggml_tensor * gate, + ggml_tensor * gate_b, + ggml_tensor * down, + ggml_tensor * down_b, + ffn_op_type type_op, + int il) const { + + ggml_tensor * tmp = up ? 
ggml_mul_mat(ctx0, up, cur) : cur; + cb(tmp, "ffn_up", il); + + if (up_b) { + tmp = ggml_add(ctx0, tmp, up_b); + cb(tmp, "ffn_up_b", il); + } + + if (gate) { + cur = ggml_mul_mat(ctx0, gate, cur); + cb(cur, "ffn_gate", il); + + if (gate_b) { + cur = ggml_add(ctx0, cur, gate_b); + cb(cur, "ffn_gate_b", il); } + } else { + cur = tmp; + } - // second half - ggml_tensor * second; - { - second = ggml_view_3d(ctx0, cur, - n_dim/2, n_head, n_pos, - ggml_row_size(cur->type, n_dim), - ggml_row_size(cur->type, n_dim*n_head), - n_dim/2 * ggml_element_size(cur)); - second = ggml_rope_ext( - ctx0, - second, - pos_b, // positions - nullptr, // freq factors - n_dim/2, // n_dims - 0, 0, freq_base, - freq_scale_odd, - 0.0f, 1.0f, 0.0f, 0.0f - ); - } + // we only support parallel ffn for now + switch (type_op) { + case FFN_SILU: + if (gate) { + cur = ggml_swiglu_split(ctx0, cur, tmp); + cb(cur, "ffn_swiglu", il); + } else { + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_silu", il); + } break; + case FFN_GELU: + if (gate) { + cur = ggml_geglu_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu", il); + } else { + cur = ggml_gelu(ctx0, cur); + cb(cur, "ffn_gelu", il); + } break; + case FFN_GELU_ERF: + if (gate) { + cur = ggml_geglu_erf_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu_erf", il); + } else { + cur = ggml_gelu_erf(ctx0, cur); + cb(cur, "ffn_gelu_erf", il); + } break; + case FFN_GELU_QUICK: + if (gate) { + cur = ggml_geglu_quick_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu_quick", il); + } else { + cur = ggml_gelu_quick(ctx0, cur); + cb(cur, "ffn_gelu_quick", il); + } break; + } - cur = ggml_concat(ctx0, first, second, 0); + if (down) { + cur = ggml_mul_mat(ctx0, down, cur); + } + + if (down_b) { + cb(cur, "ffn_down", il); + } + + if (down_b) { + cur = ggml_add(ctx0, cur, down_b); + } + + return cur; +} + +ggml_tensor * clip_graph::build_attn( + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_mask, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); + //cb(k, "k", il); + + ggml_tensor * cur; + + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + + k = ggml_cast(ctx0, k, GGML_TYPE_F16); + v = ggml_cast(ctx0, v, GGML_TYPE_F16); + + cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, 0.0f, 0.0f); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); + + } else { + ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); + v = ggml_cont(ctx0, v); + + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + // F32 may not needed for vision encoders? 
+ // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + } + + cb(cur, "kqv_out", il); + + if (wo) { + cur = ggml_mul_mat(ctx0, wo, cur); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +// implementation of the 2D RoPE without adding a new op in ggml +// this is not efficient (use double the memory), but works on all backends +// TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065 +ggml_tensor * clip_graph::build_rope_2d( + ggml_context * ctx0, + ggml_tensor * cur, + ggml_tensor * pos_a, // first half + ggml_tensor * pos_b, // second half + const float freq_base, + const bool interleave_freq +) { + const int64_t n_dim = cur->ne[0]; + const int64_t n_head = cur->ne[1]; + const int64_t n_pos = cur->ne[2]; + + // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos) + // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3 + // first half of cur will use 1e-0, 1e-2 (even) + // second half of cur will use 1e-1, 1e-3 (odd) + // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even + // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2) + // then for the second half, we use freq_scale to shift the inv_freq + // ^ why? replace (2i) with (2i+1) in the above equation + const float freq_scale_odd = interleave_freq + ? std::pow(freq_base, (float)-2/n_dim) + : 1.0; + + // first half + ggml_tensor * first; + { + first = ggml_view_3d(ctx0, cur, + n_dim/2, n_head, n_pos, + ggml_row_size(cur->type, n_dim), + ggml_row_size(cur->type, n_dim*n_head), + 0); + first = ggml_rope_ext( + ctx0, + first, + pos_a, // positions + nullptr, // freq factors + n_dim/2, // n_dims + 0, 0, freq_base, + 1.0f, 0.0f, 1.0f, 0.0f, 0.0f + ); + } + + // second half + ggml_tensor * second; + { + second = ggml_view_3d(ctx0, cur, + n_dim/2, n_head, n_pos, + ggml_row_size(cur->type, n_dim), + ggml_row_size(cur->type, n_dim*n_head), + n_dim/2 * ggml_element_size(cur)); + second = ggml_rope_ext( + ctx0, + second, + pos_b, // positions + nullptr, // freq factors + n_dim/2, // n_dims + 0, 0, freq_base, + freq_scale_odd, + 0.0f, 1.0f, 0.0f, 0.0f + ); + } + + cur = ggml_concat(ctx0, first, second, 0); + return cur; +} + +// Generic function to stack frames for audio processing +// Abstracts out the StackAudioFrames logic used by ultravox +ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) { + if (stack_factor <= 1) { return cur; } - // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) - // support dynamic resolution - ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) { - GGML_ASSERT(scale_factor > 1); + int64_t total_elements = ggml_nelements(cur); + int64_t stride = n_embed * stack_factor; - const int n_embd = cur->ne[0]; - int width = img.nx / patch_size; - int height = img.ny / patch_size; + // Calculate padded length + int64_t padded_len = GGML_PAD(total_elements, stride); + int64_t pad = padded_len - total_elements; - // pad width and height to factor - const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width; - const 
int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height; - cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height); - if (pad_width || pad_height) { - cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0); - width += pad_width; - height += pad_height; - } - - // unshuffle h - cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - - // unshuffle w - cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - - cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); - cb(cur, "pixel_shuffle", -1); - - return cur; + if (pad > 0) { + // Pad the tensor to make it divisible by stride + cur = ggml_view_1d(ctx0, cur, total_elements, 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); } -}; + // Reshape to [stride, padded_len / stride] + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); + return cur; +} + +// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) +// support dynamic resolution +ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale_factor) { + GGML_ASSERT(scale_factor > 1); + + const int n_embd = cur->ne[0]; + int width = img.nx / patch_size; + int height = img.ny / patch_size; + + // pad width and height to factor + const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width; + const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height; + cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height); + if (pad_width || pad_height) { + cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0); + width += pad_width; + height += pad_height; + } + + // unshuffle h + cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + + // unshuffle w + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); + cb(cur, "pixel_shuffle", -1); + + return cur; +} static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) { GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported"); - clip_graph graph(ctx, *imgs.entries[0]); - ggml_cgraph * res; + const clip_image_f32 & img = *imgs.entries[0]; + std::unique_ptr builder; switch (ctx->proj_type()) { case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_LFM2: + case PROJECTOR_TYPE_JANUS_PRO: { - res = graph.build_siglip(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: { - res = graph.build_pixtral(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: { - res = graph.build_qwen2vl(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_QWEN3VL: { - res = graph.build_qwen3vl(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_MINICPMV: { - res = graph.build_minicpmv(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_INTERNVL: { - res = graph.build_internvl(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_LLAMA4: { - res = graph.build_llama4(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: case 
PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: { - res = graph.build_whisper_enc(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_KIMIVL: { - res = graph.build_kimivl(); - } break; - case PROJECTOR_TYPE_JANUS_PRO: - { - res = graph.build_siglip(); + builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_COGVLM: { - res = graph.build_cogvlm(); + builder = std::make_unique(ctx, img); + } break; + case PROJECTOR_TYPE_MLP: + case PROJECTOR_TYPE_MLP_NORM: + case PROJECTOR_TYPE_LDP: + case PROJECTOR_TYPE_LDPV2: + case PROJECTOR_TYPE_GLM_EDGE: + { + builder = std::make_unique(ctx, img); + } break; + case PROJECTOR_TYPE_GLM4V: + { + builder = std::make_unique(ctx, img); } break; default: - { - res = graph.build_llava(); - } break; + GGML_ABORT("missing cgraph builder"); } - return res; + + return builder->build(); } +// +// clip_model_loader +// + struct clip_model_loader { ggml_context_ptr ctx_meta; gguf_context_ptr ctx_gguf; @@ -2860,6 +1171,14 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_GLM4V: + { + hparams.rope_theta = 10000.0f; + hparams.n_merge = 2; // default value for GLM4-V + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + hparams.set_limit_image_tokens(8, 4096); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_LLAMA4: { hparams.rope_theta = 10000.0f; @@ -2868,16 +1187,22 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_VOXTRAL: { bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || - model.proj_type == PROJECTOR_TYPE_VOXTRAL; + model.proj_type == PROJECTOR_TYPE_VOXTRAL || + model.proj_type == PROJECTOR_TYPE_GLMA; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); - if (hparams.n_mel_bins != 128) { - throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); - } hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging + + // audio preprocessing params + hparams.audio_chunk_len = 30; // in seconds + hparams.audio_sample_rate = 16000; + hparams.audio_n_fft = 400; + hparams.audio_window_len = 400; + hparams.audio_hop_len = 160; } break; default: break; @@ -2915,6 +1240,11 @@ struct clip_model_loader { LOG_INF("\n--- audio hparams ---\n"); LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins); LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor); + LOG_INF("%s: audio_chunk_len: %d\n", __func__, hparams.audio_chunk_len); + LOG_INF("%s: audio_sample_rate: %d\n", __func__, hparams.audio_sample_rate); + LOG_INF("%s: audio_n_fft: %d\n", __func__, hparams.audio_n_fft); + LOG_INF("%s: audio_window_len: %d\n", __func__, hparams.audio_window_len); + LOG_INF("%s: audio_hop_len: %d\n", __func__, hparams.audio_hop_len); } LOG_INF("\n"); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); @@ -2976,6 +1306,9 @@ struct clip_model_loader { model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); + model.norm_embd_w = get_tensor(string_format(TN_NORM_EMBD, "weight"), false); + model.norm_embd_b = get_tensor(string_format(TN_NORM_EMBD, "bias"), false); + model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false); // layers @@ -3164,6 +1497,20 @@ struct 
clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_GLM4V: + { + model.projection = get_tensor(TN_MM_PROJECTOR); + model.mm_ffn_up_w = get_tensor(string_format(TN_MM_UP, "weight")); + model.mm_ffn_up_b = get_tensor(string_format(TN_MM_UP, "bias"), false); + model.mm_ffn_gate_w = get_tensor(string_format(TN_MM_GATE, "weight")); + model.mm_ffn_gate_b = get_tensor(string_format(TN_MM_GATE, "bias"), false); + model.mm_ffn_down_w = get_tensor(string_format(TN_MM_DOWN, "weight")); + model.mm_ffn_down_b = get_tensor(string_format(TN_MM_DOWN, "bias"), false); + model.mm_post_norm_w = get_tensor(string_format(TN_MM_POST_NORM, "weight")); + model.mm_post_norm_b = get_tensor(string_format(TN_MM_POST_NORM, "bias"), false); + model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight")); + model.mm_patch_merger_b = get_tensor(string_format(TN_MM_PATCH_MERGER, "bias")); + } break; case PROJECTOR_TYPE_GEMMA3: { model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); @@ -3192,8 +1539,8 @@ struct clip_model_loader { // [IMG_BREAK] token embedding model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK); // for mistral small 3.1 - model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); - model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false); } break; case PROJECTOR_TYPE_LIGHTONOCR: { @@ -3201,8 +1548,8 @@ struct clip_model_loader { model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); - model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); - model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(string_format(TN_MM_PATCH_MERGER, "weight"), false); } break; case PROJECTOR_TYPE_ULTRAVOX: { @@ -3242,6 +1589,21 @@ struct clip_model_loader { model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; + case PROJECTOR_TYPE_GLMA: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias")); + model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); + model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias")); + model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight")); + model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight")); + } break; case PROJECTOR_TYPE_LLAMA4: { model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); @@ -3578,6 +1940,8 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params if (ctx_params.warmup) { loader.warmup(*ctx_vision); } + + // 
clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f); } if (loader.has_audio) { @@ -3988,7 +2352,14 @@ struct llava_uhd { clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size) clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices std::vector slices; + + img_tool::resize_algo interpolation_overview = img_tool::RESIZE_ALGO_BILINEAR; + bool padding_overview = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6) + std::array pad_color_overview = {0, 0, 0}; + + img_tool::resize_algo interpolation_refined = img_tool::RESIZE_ALGO_BICUBIC; bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6) + std::array pad_color_refined = {0, 0, 0}; }; static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) { @@ -4015,10 +2386,11 @@ struct llava_uhd { auto refine_size = llava_uhd::select_best_resolution( original_size, ctx->model.hparams.image_res_candidates); - res.overview_size = clip_image_size{slice_size, slice_size}; - res.refined_size = refine_size; - res.grid_size = clip_image_size{0, 0}; - res.padding_refined = true; + res.overview_size = clip_image_size{slice_size, slice_size}; + res.refined_size = refine_size; + res.grid_size = clip_image_size{0, 0}; + res.padding_refined = true; + res.interpolation_refined = img_tool::RESIZE_ALGO_BILINEAR; // preserve old behavior when padding LOG_DBG("%s: using pinpoints for slicing\n", __func__); LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d\n", @@ -4097,12 +2469,13 @@ struct llava_uhd { static std::vector slice_image(const clip_image_u8 * img, const slice_instructions & inst) { std::vector output; - img_tool::resize_algo interpolation = img_tool::RESIZE_ALGO_BILINEAR; // TODO: make it configurable // resize to overview size clip_image_u8_ptr resized_img(clip_image_u8_init()); - img_tool::resize(*img, *resized_img, inst.overview_size, interpolation); + img_tool::resize(*img, *resized_img, inst.overview_size, inst.interpolation_overview, + inst.padding_overview, inst.pad_color_overview); output.push_back(std::move(resized_img)); + if (inst.slices.empty()) { // no slices, just return the resized image return output; @@ -4110,13 +2483,8 @@ struct llava_uhd { // resize to refined size clip_image_u8_ptr refined_img(clip_image_u8_init()); - if (inst.padding_refined) { - img_tool::resize(*img, *refined_img, inst.refined_size, interpolation); - } else { - // only algo bicubic preserves the ratio; old models rely on this behavior - // TODO: do we need to support other algos here? 
- img_tool::resize(*img, *refined_img, inst.refined_size, img_tool::RESIZE_ALGO_BICUBIC, false); - } + img_tool::resize(*img, *refined_img, inst.refined_size, inst.interpolation_refined, + inst.padding_refined, inst.pad_color_refined); // create slices for (const auto & slice : inst.slices) { @@ -4283,6 +2651,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_GLM4V: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); clip_image_u8 resized; @@ -4525,16 +2894,30 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; const int n_total = clip_n_output_tokens(ctx, img); - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { - return img->nx / (params.patch_size * 2); + const auto & proj = ctx->proj_type(); + switch (proj) { + case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_GLM4V: + return (img->nx / params.patch_size) / 2; + default: + break; } return n_total; } int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { - return img->ny / (params.patch_size * 2); + const auto & proj = ctx->proj_type(); + switch (proj) { + case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_GLM4V: + return (img->ny / params.patch_size) / 2; + default: + break; } return 1; } @@ -4591,6 +2974,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_GLM4V: { // dynamic size (2 conv, so double patch size) int x_patch = img->nx / (params.patch_size * 2); @@ -4649,6 +3033,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches /= 2; } } break; + case PROJECTOR_TYPE_GLMA: + { + n_patches = img->nx; + // whisper downscales input token by half after conv1d + n_patches /= 2; + // reshape by merge_factor + n_patches /= ctx->model.hparams.proj_stack_factor; + // for BOI and EOI token embeddings + n_patches += 2; + } break; case PROJECTOR_TYPE_COGVLM: { n_patches += 2; // for BOI and EOI token embeddings @@ -4828,6 +3222,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_GLM4V: { const int merge_ratio = hparams.n_merge; const int pw = image_size_width / patch_size; @@ -4984,6 +3379,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_VOXTRAL: @@ -5053,7 +3449,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // copy the embeddings to the location passed by the user - ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + 
if (vec != nullptr) { + ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + } return true; } @@ -5094,11 +3492,15 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; + case PROJECTOR_TYPE_GLMA: + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; + case PROJECTOR_TYPE_GLM4V: + return ctx->model.mm_ffn_down_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } @@ -5115,10 +3517,11 @@ bool clip_is_glm(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE; } -bool clip_is_qwen2vl(const struct clip_ctx * ctx) { +bool clip_is_mrope(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL - || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL; + || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL + || ctx->proj_type() == PROJECTOR_TYPE_GLM4V; } bool clip_is_llava(const struct clip_ctx * ctx) { @@ -5140,6 +3543,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) { bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A + || ctx->proj_type() == PROJECTOR_TYPE_GLMA || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL; } @@ -5174,3 +3578,26 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel batch->entries.push_back(clip_image_f32_ptr(audio)); batch->is_audio = true; } + +const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) { + return &ctx->model.hparams; +} + +// +// API for debugging +// + +void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) { + clip_image_f32 img; + img.nx = w; + img.ny = h; + img.buf.resize(h * w * 3); + for (int i = 0; i < h * w * 3; i++) { + img.buf[i] = static_cast(fill_value); + } + bool cur_debug_graph = ctx->debug_graph; + ctx->debug_graph = true; + clip_image_encode(ctx, 1, &img, nullptr); + ctx->debug_graph = cur_debug_graph; + GGML_ASSERT(img.buf.empty() && "expected, always stop here"); +} diff --git a/llama/llama.cpp/tools/mtmd/clip.h b/llama/llama.cpp/tools/mtmd/clip.h index e8aeb206..68a0d6e8 100644 --- a/llama/llama.cpp/tools/mtmd/clip.h +++ b/llama/llama.cpp/tools/mtmd/clip.h @@ -7,6 +7,8 @@ // !!! Internal header, to be used by mtmd only !!! 
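The clip.cpp hunk above introduces a clip_debug_encode() helper under "API for debugging": it builds a synthetic, constant-filled image, temporarily enables debug_graph, and runs the encoder with a null output pointer (made safe by the `if (vec != nullptr)` guard added in the same file). A minimal usage sketch follows; the wrapper name and the 336x336 size are illustrative assumptions that mirror the commented-out `clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f)` call left in clip_init(), and are not part of this patch.

// Sketch only. clip_debug_encode() is defined in clip.cpp by this patch; the
// forward declarations below are assumed for callers outside that translation unit.
struct clip_ctx; // opaque, declared in clip.h
void clip_debug_encode(struct clip_ctx * ctx, int h, int w, float fill_value);

static void debug_dump_vision_graph(struct clip_ctx * ctx_vision) {
    // 24*14 = 336 px per side, every channel filled with 0.5f
    clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f);
    // not reached: after encoding with debug_graph enabled, clip_debug_encode
    // deliberately stops in its final GGML_ASSERT ("expected, always stop here")
}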
+#define MTMD_INTERNAL_HEADER + struct clip_ctx; struct clip_image_size { @@ -102,7 +104,7 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct int clip_is_minicpmv(const struct clip_ctx * ctx); bool clip_is_glm(const struct clip_ctx * ctx); -bool clip_is_qwen2vl(const struct clip_ctx * ctx); +bool clip_is_mrope(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); diff --git a/llama/llama.cpp/tools/mtmd/models/cogvlm.cpp b/llama/llama.cpp/tools/mtmd/models/cogvlm.cpp new file mode 100644 index 00000000..d5b739c6 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/cogvlm.cpp @@ -0,0 +1,98 @@ +#include "models.h" + +ggml_cgraph * clip_graph_cogvlm::build() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; // +1 for [CLS] + + // build input and concatenate class embedding + ggml_tensor * inp = build_inp(); + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + inp = ggml_add(ctx0, inp, model.position_embeddings); + cb(inp, "inp_pos", -1); + + ggml_tensor * inpL = inp; + + for (int il = 0; il < n_layer; il++) { + auto & layer = model.layers[il]; + ggml_tensor * cur = inpL; + + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + + cur = ggml_add(ctx0, cur, layer.qkv_b); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), + cur->nb[1], 0); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), + cur->nb[1], n_embd * sizeof(float)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float), + cur->nb[1], 2 * n_embd * sizeof(float)); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "attn_post_norm", il); + + cur = ggml_add(ctx0, cur, inpL); + inpL = cur; + + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + cb(cur, "ffn_out", il); + + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "layer_out", il); + inpL = cur; + + } + + // remove CLS token (like build_llama4 does) + ggml_tensor * cur = ggml_view_2d(ctx0, inpL, + n_embd, n_patches, + ggml_row_size(inpL->type, n_embd), 0); + + // Multiply with mm_model_proj + cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); + + // Apply layernorm, weight, bias + cur = build_norm(cur, model.mm_post_fc_norm_w, model.mm_post_fc_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); + + // Apply GELU + cur = ggml_gelu_inplace(ctx0, cur); + + // Branch 1: multiply with mm_h_to_4h_w + ggml_tensor * h_to_4h = ggml_mul_mat(ctx0, model.mm_h_to_4h_w, cur); + + // Branch 2: multiply with mm_gate_w + ggml_tensor * gate = ggml_mul_mat(ctx0, model.mm_gate_w, cur); + + // Apply silu + gate = ggml_swiglu_split(ctx0, gate, h_to_4h); + + // Apply mm_4h_to_h_w + cur = ggml_mul_mat(ctx0, model.mm_4h_to_h_w, gate); + + // Concatenate with boi and eoi + cur = ggml_concat(ctx0, model.mm_boi, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git 
a/llama/llama.cpp/tools/mtmd/models/glm4v.cpp b/llama/llama.cpp/tools/mtmd/models/glm4v.cpp new file mode 100644 index 00000000..f39b6922 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/glm4v.cpp @@ -0,0 +1,120 @@ +#include "models.h" + +ggml_cgraph * clip_graph_glm4v::build() { + GGML_ASSERT(model.patch_bias != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + + const int batch_size = 1; + + norm_type norm_t = NORM_TYPE_RMS; + + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + // second conv dimension + { + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + n_embd, n_patches_x * n_patches_y, batch_size); + } + + // add patch bias + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + + // pos-conv norm + inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1); + + // calculate absolute position embedding and apply + ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC); + learned_pos_embd = ggml_cont_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + learned_pos_embd = ggml_reshape_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3); + learned_pos_embd = ggml_cont_3d( + ctx0, learned_pos_embd, + n_embd, n_patches_x * n_patches_y, batch_size); + cb(learned_pos_embd, "learned_pos_embd", -1); + + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return ggml_rope_multi( + ctx0, cur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, + 32768, hparams.rope_theta, 1, 0, 1, 32, 1); + }; + + ggml_tensor * cur = build_vit( + inp, n_patches, + norm_t, + hparams.ffn_op, + learned_pos_embd, + add_pos); + + cb(cur, "vit_out", -1); + // cb(ggml_sum(ctx0, cur), "vit_out_sum", -1); + + // GLM4V projector + // ref: https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/models/glm4v/modeling_glm4v.py#L116-L130 + + // patch merger (downsample) + { + int n_merge = hparams.n_merge; + GGML_ASSERT(n_merge > 0); + + int n_token_out = n_patches / n_merge / n_merge; + cur = ggml_reshape_4d(ctx0, cur, n_embd, n_merge, n_merge, n_token_out); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); // [n_merge, n_merge, n_embd, n_token_out] + cur = ggml_conv_2d(ctx0, model.mm_patch_merger_w, cur, n_merge, n_merge, 0, 0, 1, 1); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], n_token_out); // [n_embd_out, n_token_out] + + cur = ggml_add(ctx0, cur, model.mm_patch_merger_b); + } + + 
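Before the FC projector below, it may help to spell out the token arithmetic the patch merger above implies: the ViT yields one embedding per patch_size x patch_size patch, and the stride-n_merge convolution collapses every n_merge x n_merge group into a single output token, which lines up with the GLM4V case added to clip_n_output_tokens() earlier in this diff for the default n_merge of 2. The helper below is an illustrative sketch only, not code added by this patch; the function name and the concrete sizes in the example are assumptions.

// Illustrative sketch of the GLM4V output-token count. Assumes nx and ny are
// multiples of patch_size * n_merge, as the graph asserts for n_merge == 2.
static int glm4v_n_output_tokens(int nx, int ny, int patch_size, int n_merge) {
    const int n_patches_x = nx / patch_size;
    const int n_patches_y = ny / patch_size;
    // same as (nx / (patch_size * n_merge)) * (ny / (patch_size * n_merge))
    return (n_patches_x * n_patches_y) / (n_merge * n_merge);
}
// example: nx = ny = 336, patch_size = 14, n_merge = 2  ->  (24 * 24) / 4 = 144 tokens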
// FC projector + { + cur = ggml_mul_mat(ctx0, model.projection, cur); + // default LayerNorm (post_projection_norm) + cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1); + cur = ggml_gelu_erf(ctx0, cur); + cb(cur, "after_fc_proj", -1); + } + + // FFN projector + { + cur = build_ffn(cur, + model.mm_ffn_up_w, model.mm_ffn_up_b, + model.mm_ffn_gate_w, model.mm_ffn_gate_b, + model.mm_ffn_down_w, model.mm_ffn_down_b, + hparams.ffn_op, -1); + cb(cur, "after_ffn_proj", -1); + // cb(ggml_sum(ctx0, cur), "merged_sum", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/internvl.cpp b/llama/llama.cpp/tools/mtmd/models/internvl.cpp new file mode 100644 index 00000000..9aded3b9 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/internvl.cpp @@ -0,0 +1,69 @@ +#include "models.h" + +ggml_cgraph * clip_graph_internvl::build() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; + ggml_tensor * inp = build_inp(); + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + // The larger models use a different ViT, which uses RMS norm instead of layer norm + // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 + norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) + ? NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B) + : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models) + + ggml_tensor * cur = build_vit( + inp, n_pos, + norm_t, + hparams.ffn_op, + model.position_embeddings, + nullptr); + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + { + const int scale_factor = model.hparams.n_merge; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = n_patches_y; + const int width = n_patches_x; + GGML_ASSERT(scale_factor > 0); + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont_4d(ctx0, cur, + n_embd * scale_factor * scale_factor, + height / scale_factor, + width / scale_factor, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_cont_2d(ctx0, cur, + n_embd * scale_factor * scale_factor, + cur->ne[1] * cur->ne[2]); + } + + // projector (always using GELU activation) + { + // projector LayerNorm uses pytorch's default eps = 1e-5 + // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 + cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_3_w, model.mm_3_b, + FFN_GELU, + -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/kimivl.cpp b/llama/llama.cpp/tools/mtmd/models/kimivl.cpp new file mode 100644 index 00000000..0a06f509 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/kimivl.cpp @@ -0,0 +1,63 @@ +#include "models.h" + +ggml_cgraph * clip_graph_kimivl::build() { + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = 
ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + + // build ViT with 2D position embeddings + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + // first half is X axis and second half is Y axis + return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + }; + + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + learned_pos_embd, + add_pos); + + cb(cur, "vit_out", -1); + + { + // patch_merger + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection norm + int proj_inp_dim = cur->ne[0]; + cur = ggml_view_2d(ctx0, cur, + n_embd, cur->ne[1] * scale_factor * scale_factor, + ggml_row_size(cur->type, n_embd), 0); + cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + cur = ggml_view_2d(ctx0, cur, + proj_inp_dim, cur->ne[1] / scale_factor / scale_factor, + ggml_row_size(cur->type, proj_inp_dim), 0); + cb(cur, "proj_inp_normed", -1); + + // projection mlp + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + cb(cur, "proj_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/llama4.cpp b/llama/llama.cpp/tools/mtmd/models/llama4.cpp new file mode 100644 index 00000000..30d1df5b --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/llama4.cpp @@ -0,0 +1,96 @@ +#include "models.h" + +ggml_cgraph * clip_graph_llama4::build() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; // +1 for [CLS] + + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * inp = build_inp_raw(); + + // Llama4UnfoldConvolution + { + ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0, + patch_size, patch_size, 3, n_embd); + inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type); + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + cb(inp, "patch_conv", -1); + } + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + // build ViT with 2D position embeddings + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + // first half is X axis and second half is Y axis + // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312 + // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441 + return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + }; + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + model.position_embeddings, + add_pos); + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + // based on 
Llama4VisionPixelShuffleMLP + // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151 + { + const int scale_factor = model.hparams.n_merge; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + GGML_ASSERT(scale_factor > 0); + GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images + cur = ggml_reshape_4d(ctx0, cur, + n_embd * scale_factor, + n_patches_x / scale_factor, + n_patches_y, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont_4d(ctx0, cur, + n_embd * scale_factor * scale_factor, + n_patches_x / scale_factor, + n_patches_y / scale_factor, + bsz); + //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_cont_2d(ctx0, cur, + n_embd * scale_factor * scale_factor, + n_patches / scale_factor / scale_factor); + cb(cur, "pixel_shuffle", -1); + } + + // based on Llama4VisionMLP2 (always uses GELU activation, no bias) + { + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); + cur = ggml_gelu(ctx0, cur); + cb(cur, "adapter_mlp", -1); + } + + // Llama4MultiModalProjector + cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); + cb(cur, "projected", -1); + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/llava.cpp b/llama/llama.cpp/tools/mtmd/models/llava.cpp new file mode 100644 index 00000000..0bfb5f05 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/llava.cpp @@ -0,0 +1,374 @@ +#include "models.h" + +// this graph is used by llava, granite and glm +// due to having embedding_stack (used by granite), we cannot reuse build_vit +ggml_cgraph * clip_graph_llava::build() { + const int batch_size = 1; + const int n_pos = n_patches + (model.class_embedding ? 1 : 0); + + GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported"); + + // Calculate the deepest feature layer based on hparams and projector type + int max_feature_layer = n_layer; + { + // Get the index of the second to last layer; this is the default for models that have a llava projector + int il_last = hparams.n_layer - 1; + int deepest_feature_layer = -1; + + if (proj_type == PROJECTOR_TYPE_MINICPMV || proj_type == PROJECTOR_TYPE_GLM_EDGE) { + il_last += 1; + } + + // If we set explicit vision feature layers, only go up to the deepest one + // NOTE: only used by granite-vision models for now + for (const auto & feature_layer : hparams.vision_feature_layer) { + if (feature_layer > deepest_feature_layer) { + deepest_feature_layer = feature_layer; + } + } + max_feature_layer = deepest_feature_layer < 0 ? 
il_last : deepest_feature_layer; + } + + ggml_tensor * inp = build_inp(); + + // concat class_embeddings and patch_embeddings + if (model.class_embedding) { + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + } + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions)); + + ggml_tensor * inpL = inp; + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1); + cb(inpL, "pre_ln", -1); + } + + std::vector embedding_stack; + const auto & vision_feature_layer = hparams.vision_feature_layer; + + // loop over layers + for (int il = 0; il < max_feature_layer; il++) { + auto & layer = model.layers[il]; + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // If this is an embedding feature layer, save the output. + // NOTE: 0 index here refers to the input to the encoder. + if (vision_feature_layer.find(il) != vision_feature_layer.end()) { + embedding_stack.push_back(cur); + } + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "layer_inp_normed", il); + + // self-attention + { + ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + if (layer.q_b) { + Qcur = ggml_add(ctx0, Qcur, layer.q_b); + } + + ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + if (layer.k_b) { + Kcur = ggml_add(ctx0, Kcur, layer.k_b); + } + + ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + if (layer.v_b) { + Vcur = ggml_add(ctx0, Vcur, layer.v_b); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + cb(cur, "ffn_inp", il); + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "ffn_inp_normed", il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + cb(cur, "ffn_out", il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + inpL = cur; + } + + // post-layernorm + if (model.post_ln_w) { + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1); + } + + ggml_tensor * embeddings = inpL; + + // process vision feature layers (used by granite) + { + // final layer is a vision feature layer + if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) { + embedding_stack.push_back(inpL); + } + + // If feature layers are explicitly set, stack them (if we have multiple) + if (!embedding_stack.empty()) { + embeddings = embedding_stack[0]; + for (size_t i = 1; i < embedding_stack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); + } + } + } + + // llava projector (also used by granite) + if (hparams.has_llava_projector) { + embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + + ggml_tensor * patches = 
ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(patches, "patches"); + ggml_set_input(patches); + + // shape [1, 576, 1024] + // ne is whcn, ne = [1024, 576, 1, 1] + embeddings = ggml_get_rows(ctx0, embeddings, patches); + + // print_tensor_info(embeddings, "embeddings"); + + // llava projector + if (proj_type == PROJECTOR_TYPE_MLP) { + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + embeddings = ggml_gelu(ctx0, embeddings); + if (model.mm_2_w) { + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + } + } + else if (proj_type == PROJECTOR_TYPE_MLP_NORM) { + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); + // First LayerNorm + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w), + model.mm_1_b); + + // GELU activation + embeddings = ggml_gelu(ctx0, embeddings); + + // Second linear layer + embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_3_b); + + // Second LayerNorm + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w), + model.mm_4_b); + } + else if (proj_type == PROJECTOR_TYPE_LDP) { + // MobileVLM projector + int n_patch = 24; + ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings); + mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b); + mlp_1 = ggml_gelu(ctx0, mlp_1); + ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1); + mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b); + // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1] + + // block 1 + ggml_tensor * block_1 = nullptr; + { + // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24] + mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3); + mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); + // stride = 1, padding = 1, bias is nullptr + block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); + + // layer norm + // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3)); + // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] + block_1 = ggml_norm(ctx0, block_1, eps); + block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b); + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); + + // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] + // hardswish + ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1); + + block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0); + // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] + // pointwise conv + block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); + block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1); + block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b); + block_1 = ggml_relu(ctx0, block_1); + block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1); + block_1 = ggml_add(ctx0, block_1, 
model.mm_model_block_1_block_1_fc2_b); + block_1 = ggml_hardsigmoid(ctx0, block_1); + // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1] + block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]); + block_1 = ggml_mul(ctx0, block_1_hw, block_1); + + int w = block_1->ne[0], h = block_1->ne[1]; + block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); + + // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] + block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1); + block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); + + // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] + block_1 = ggml_norm(ctx0, block_1, eps); + block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b); + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); + // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] + // residual + block_1 = ggml_add(ctx0, mlp_3, block_1); + } + + // block_2 + { + // stride = 2 + block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1); + + // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] + // layer norm + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3)); + // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1] + block_1 = ggml_norm(ctx0, block_1, eps); + block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b); + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); + // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] + // hardswish + ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1); + + // not sure the parameters is right for globalAvgPooling + block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0); + // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] + // pointwise conv + block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); + block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1); + block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b); + block_1 = ggml_relu(ctx0, block_1); + block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1); + block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b); + block_1 = ggml_hardsigmoid(ctx0, block_1); + + // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] + block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]); + block_1 = ggml_mul(ctx0, block_1_hw, block_1); + + int w = block_1->ne[0], h = block_1->ne[1]; + block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); + block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); + // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] + block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1); + block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); + + + // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1] + block_1 = ggml_norm(ctx0, block_1, eps); + block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), 
model.mm_model_block_2_block_2_1_b); + block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]); + // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1] + } + embeddings = block_1; + } + else if (proj_type == PROJECTOR_TYPE_LDPV2) + { + int n_patch = 24; + ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); + mlp_0 = ggml_gelu(ctx0, mlp_0); + ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); + mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); + // mlp_2 ne = [2048, 576, 1, 1] + // // AVG Pool Layer 2*2, strides = 2 + mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3); + // mlp_2 ne = [576, 2048, 1, 1] + mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + // mlp_2 ne [24, 24, 2048, 1] + mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); + // weight ne = [3, 3, 2048, 1] + ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); + peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, mlp_2); + peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); + embeddings = peg_0; + } + else { + GGML_ABORT("fatal error"); + } + } + + // glm projector + else if (proj_type == PROJECTOR_TYPE_GLM_EDGE) { + size_t gridsz = (size_t)sqrt(embeddings->ne[1]); + embeddings = ggml_permute(ctx0,embeddings,1,0,2,3); + embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); + embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); + embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); + // GLU + { + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + embeddings = ggml_gelu_inplace(ctx0, embeddings); + ggml_tensor * x = embeddings; + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); + x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); + embeddings = ggml_swiglu_split(ctx0, embeddings, x); + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); + } + // arrangement of BOI/EOI token embeddings + // note: these embeddings are not present in text model, hence we cannot process them as text tokens + // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53 + { + embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI + embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI + } + } + + else { + GGML_ABORT("llava: unknown projector type"); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/minicpmv.cpp b/llama/llama.cpp/tools/mtmd/models/minicpmv.cpp new file mode 100644 index 00000000..3594ea29 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/minicpmv.cpp @@ -0,0 +1,114 @@ +#include "models.h" + +ggml_cgraph * clip_graph_minicpmv::build() { + GGML_ASSERT(model.class_embedding == 
nullptr); + const int n_pos = n_patches; + const int n_embd_proj = n_mmproj_embd; + + // position embeddings for the projector (not for ViT) + // see: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/resampler.py#L70 + // base frequency omega + ggml_tensor * omega = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_proj / 4); + ggml_set_name(omega, "omega"); + ggml_set_input(omega); + + // 2D input positions (using float for sinusoidal embeddings) + ggml_tensor * pos_h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + ggml_tensor * pos_w = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + // for selecting learned pos embd, used by ViT + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); + + ggml_tensor * inp = build_inp(); + ggml_tensor * embeddings = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + learned_pos_embd, + nullptr); + + // resampler projector (it is just another transformer) + + ggml_tensor * q = model.mm_model_query; + ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); + + // norm + q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1); + v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1); + + // calculate sinusoidal pos embd + ggml_tensor * pos_embed = nullptr; + { + // outer product + ggml_tensor * omega_b = ggml_repeat_4d(ctx0, omega, omega->ne[0], n_pos, 1, 1); // n_pos rows + ggml_tensor * theta_x = ggml_mul(ctx0, omega_b, pos_w); + ggml_tensor * theta_y = ggml_mul(ctx0, omega_b, pos_h); + // sin and cos + ggml_tensor * pos_embd_x = ggml_concat( + ctx0, + ggml_sin(ctx0, theta_x), + ggml_cos(ctx0, theta_x), + 0 // concat on first dim + ); + ggml_tensor * pos_embd_y = ggml_concat( + ctx0, + ggml_sin(ctx0, theta_y), + ggml_cos(ctx0, theta_y), + 0 // concat on first dim + ); + pos_embed = ggml_concat(ctx0, pos_embd_x, pos_embd_y, 0); + } + + // k = v + pos_embed + ggml_tensor * k = ggml_add(ctx0, v, pos_embed); + + // attention + { + const int d_head = 128; + int n_head = n_embd_proj/d_head; + // Use actual config value if available, otherwise fall back to hardcoded values + int num_query = hparams.minicpmv_query_num; + ggml_tensor * Q = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), + model.mm_model_attn_q_b); + ggml_tensor * K = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), + model.mm_model_attn_k_b); + ggml_tensor * V = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), + model.mm_model_attn_v_b); + + Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query); + K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos); + V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos); + + cb(Q, "resampler_Q", -1); + cb(K, "resampler_K", -1); + cb(V, "resampler_V", -1); + + float resampler_kq_scale = 1.0f/ sqrtf(float(d_head)); + embeddings = build_attn( + model.mm_model_attn_o_w, + model.mm_model_attn_o_b, + Q, K, V, nullptr, resampler_kq_scale, -1); + cb(embeddings, "resampler_attn_out", -1); + } + // layernorm + embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1); + + // projection + embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, 
embeddings); + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/models.go b/llama/llama.cpp/tools/mtmd/models/models.go new file mode 100644 index 00000000..77fa35e6 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/models.go @@ -0,0 +1,6 @@ +package models + +// #cgo CXXFLAGS: -std=c++17 +// #cgo CPPFLAGS: -I${SRCDIR}/../../../include -I${SRCDIR}/../../../common -I${SRCDIR}/../../../vendor +// #cgo CPPFLAGS: -I${SRCDIR}/../../../../../ml/backend/ggml/ggml/include +import "C" diff --git a/llama/llama.cpp/tools/mtmd/models/models.h b/llama/llama.cpp/tools/mtmd/models/models.h new file mode 100644 index 00000000..0496d6b2 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/models.h @@ -0,0 +1,63 @@ +#pragma once + +#include "../clip-graph.h" + +struct clip_graph_siglip : clip_graph { + clip_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_pixtral : clip_graph { + clip_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_qwen2vl : clip_graph { + clip_graph_qwen2vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_qwen3vl : clip_graph { + clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_minicpmv : clip_graph { + clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_internvl : clip_graph { + clip_graph_internvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_llama4 : clip_graph { + clip_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_kimivl : clip_graph { + clip_graph_kimivl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_cogvlm : clip_graph { + clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_llava : clip_graph { + clip_graph_llava(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_whisper_enc : clip_graph { + clip_graph_whisper_enc(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + +struct clip_graph_glm4v : clip_graph { + clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; diff --git a/llama/llama.cpp/tools/mtmd/models/pixtral.cpp b/llama/llama.cpp/tools/mtmd/models/pixtral.cpp new file mode 100644 index 00000000..a849210b --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/pixtral.cpp @@ -0,0 +1,86 @@ +#include "models.h" + +ggml_cgraph * clip_graph_pixtral::build() { + const int n_merge = hparams.n_merge; + + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + 
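A note on the two position tensors above: they are declared as graph inputs, so the per-patch row and column indices are filled in by the caller before evaluation and then consumed by build_rope_2d through the add_pos callback defined just below. As a rough host-side sketch only (assuming a plain row-major patch order; the helper name and the std::vector buffers are illustrative and not part of this patch):

#include <cstdint>
#include <vector>

// Sketch: populate per-patch 2D coordinates for an n_patches_x by n_patches_y grid,
// assuming patches are laid out row-major. pos_h receives the row index and pos_w
// the column index of each patch; the real code feeds equivalent buffers into the
// "pos_h"/"pos_w" graph inputs named above.
void fill_patch_positions_rowmajor(int n_patches_x, int n_patches_y,
                                   std::vector<int32_t> & pos_h,
                                   std::vector<int32_t> & pos_w) {
    const size_t n = (size_t) n_patches_x * n_patches_y;
    pos_h.resize(n);
    pos_w.resize(n);
    for (int y = 0; y < n_patches_y; ++y) {
        for (int x = 0; x < n_patches_x; ++x) {
            const size_t i = (size_t) y * n_patches_x + x;
            pos_h[i] = y; // row coordinate, used for the "h" half of the 2D RoPE
            pos_w[i] = x; // column coordinate, used for the "w" half
        }
    }
}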
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true); + }; + + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_RMS, + hparams.ffn_op, + nullptr, // no learned pos embd + add_pos); + + // mistral small 3.1 patch merger + // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67 + if (model.mm_patch_merger_w) { + GGML_ASSERT(hparams.n_merge > 0); + + cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w); + + // reshape image tokens to 2D grid + cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y); + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd] + cur = ggml_cont(ctx0, cur); + + // torch.nn.functional.unfold is just an im2col under the hood + // we just need a dummy kernel to make it work + ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0); + cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type); + + // project to n_embd + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); + cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur); + } + + // LlavaMultiModalProjector (always using GELU activation) + { + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + } + + // arrangement of the [IMG_BREAK] token + if (model.token_embd_img_break) { + // not efficient, but works + // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows] + // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension + // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows] + + const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y; + const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x; + const int p_total = p_x * p_y; + const int n_embd_text = cur->ne[0]; + const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row + + ggml_tensor * tmp = ggml_reshape_3d(ctx0, cur, n_embd_text, p_x, p_y); + ggml_tensor * tok = ggml_new_tensor_3d(ctx0, tmp->type, n_embd_text, 1, p_y); + tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor + tok = ggml_add(ctx0, tok, model.token_embd_img_break); + tmp = ggml_concat(ctx0, tmp, tok, 1); + cur = ggml_view_2d(ctx0, tmp, + n_embd_text, n_tokens_output, + ggml_row_size(tmp->type, n_embd_text), 0); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/qwen2vl.cpp b/llama/llama.cpp/tools/mtmd/models/qwen2vl.cpp new file mode 100644 index 00000000..85f158bb --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/qwen2vl.cpp @@ -0,0 +1,183 @@ +#include "models.h" + +ggml_cgraph * clip_graph_qwen2vl::build() { + GGML_ASSERT(model.patch_bias == nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + + const int batch_size = 1; + const bool use_window_attn = hparams.n_wa_pattern > 0; + const int n_wa_pattern = hparams.n_wa_pattern; + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position + + norm_type norm_t = proj_type == PROJECTOR_TYPE_QWEN25VL + ? 
NORM_TYPE_RMS // qwen 2.5 vl + : NORM_TYPE_NORMAL; // qwen 2 vl + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + // second conv dimension + { + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + n_embd, n_patches_x * n_patches_y, batch_size); + } + + ggml_tensor * inpL = inp; + ggml_tensor * window_mask = nullptr; + ggml_tensor * window_idx = nullptr; + ggml_tensor * inv_window_idx = nullptr; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + + if (use_window_attn) { + // handle window attention inputs + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // if flash attn is used, we need to pad the mask and cast to f16 + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); + } + + // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); + inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const bool full_attn = use_window_attn ? 
(il + 1) % n_wa_pattern == 0 : true; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + cb(cur, "ln1", il); + + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // apply M-RoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + ggml_tensor * attn_mask = full_attn ? nullptr : window_mask; + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + cb(cur, "attn_out", il); + } + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + cb(cur, "ffn_inp", il); + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + cb(cur, "ffn_inp_normed", il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + cb(cur, "ffn_out", il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + inpL = cur; + } + + // post-layernorm + if (model.post_ln_w) { + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // multimodal projection + ggml_tensor * embeddings = inpL; + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); + + if (use_window_attn) { + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + + // embeddings shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp b/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp new file mode 100644 index 00000000..35a42cb8 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp @@ -0,0 +1,191 @@ +#include "models.h" + +ggml_cgraph * clip_graph_qwen3vl::build() { + GGML_ASSERT(model.patch_bias != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + + const int batch_size = 1; + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per 
position + + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + // second conv dimension + { + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + n_embd, n_patches_x * n_patches_y, batch_size); + } + + // add patch bias + if (model.patch_bias != nullptr) { + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + } + + // calculate absolute position embedding and apply + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + learned_pos_embd = ggml_cont_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + learned_pos_embd = ggml_reshape_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3); + learned_pos_embd = ggml_cont_3d( + ctx0, learned_pos_embd, + n_embd, n_patches_x * n_patches_y, batch_size); + inp = ggml_add(ctx0, inp, learned_pos_embd); + cb(inp, "inp_pos_emb", -1); + + ggml_tensor * inpL = inp; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + + // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size] + ggml_tensor * deepstack_features = nullptr; + const int merge_factor = hparams.n_merge > 0 ? 
hparams.n_merge * hparams.n_merge : 4; // default 2x2=4 for qwen3vl + + // loop over layers + for (int il = 0; il < n_layer; il++) { + auto & layer = model.layers[il]; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + cb(cur, "ln1", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + cur = ggml_add(ctx0, cur, layer.qkv_b); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ 0); + + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, n_embd)); + + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // apply M-RoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + cb(cur, "ffn_inp", il); + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + cb(cur, "ffn_inp_normed", il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + cb(cur, "ffn_out", il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + if (layer.has_deepstack()) { + ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size); + feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il); + feat = build_ffn(feat, + layer.deepstack_fc1_w, layer.deepstack_fc1_b, + nullptr, nullptr, + layer.deepstack_fc2_w, layer.deepstack_fc2_b, + ffn_op_type::FFN_GELU, il); + + if(!deepstack_features) { + deepstack_features = feat; + } else { + // concat along the feature dimension + deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0); + } + } + + inpL = cur; + } + + // post-layernorm + if (model.post_ln_w) { + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // multimodal projection + ggml_tensor * embeddings = inpL; + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + ffn_op_type::FFN_GELU, -1); + + embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/siglip.cpp b/llama/llama.cpp/tools/mtmd/models/siglip.cpp new file mode 100644 index 00000000..ef094cfd --- /dev/null +++ 
b/llama/llama.cpp/tools/mtmd/models/siglip.cpp @@ -0,0 +1,81 @@ +#include "models.h" + +ggml_cgraph * clip_graph_siglip::build() { + ggml_tensor * inp = build_inp(); + + ggml_tensor * learned_pos_embd = model.position_embeddings; + if (proj_type == PROJECTOR_TYPE_LFM2) { + learned_pos_embd = resize_position_embeddings(); + } + + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + learned_pos_embd, + nullptr); + + if (proj_type == PROJECTOR_TYPE_GEMMA3) { + const int batch_size = 1; + GGML_ASSERT(n_patches_x == n_patches_y); + const int patches_per_image = n_patches_x; + const int kernel_size = hparams.n_merge; + + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); + + // doing a pool2d to reduce the number of output tokens + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size); + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // apply norm before projection + cur = ggml_rms_norm(ctx0, cur, eps); + cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w); + + // apply projection + cur = ggml_mul_mat(ctx0, + ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), + cur); + + } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) { + // pixel_shuffle + // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + cur = ggml_mul_mat(ctx0, model.projection, cur); + + } else if (proj_type == PROJECTOR_TYPE_LFM2) { + // pixel unshuffle block + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection + cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + + } else if (proj_type == PROJECTOR_TYPE_JANUS_PRO) { + cur = build_ffn(cur, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + hparams.ffn_op, + -1); + + } else { + GGML_ABORT("SigLIP: Unsupported projector type"); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp b/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp new file mode 100644 index 00000000..2870d854 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp @@ -0,0 +1,106 @@ +#include "models.h" + +ggml_cgraph * clip_graph_whisper_enc::build() { + const int n_frames = img.nx; + const int n_pos = n_frames / 2; + GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos); + + ggml_tensor * inp = build_inp_raw(1); + + // conv1d block + { + // convolution + gelu + ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1); + cur = ggml_add(ctx0, cur, model.conv1d_1_b); + + cur = ggml_gelu_erf(ctx0, cur); + + cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1); + cur = ggml_add(ctx0, cur, model.conv1d_2_b); + + cur = ggml_gelu_erf(ctx0, cur); + // transpose + inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + cb(inp, "after_conv1d", -1); + } + + // sanity check (only check one layer, but it should be the same for 
all) + GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b); + GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b); + GGML_ASSERT(model.layers[0].q_b); + GGML_ASSERT(model.layers[0].v_b); + GGML_ASSERT(!model.layers[0].k_b); // no bias for k + + ggml_tensor * pos_embd_selected = ggml_view_2d( + ctx0, model.position_embeddings, + model.position_embeddings->ne[0], n_pos, + model.position_embeddings->nb[1], 0 + ); + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + pos_embd_selected, + nullptr); + + cb(cur, "after_transformer", -1); + + if (model.audio_has_stack_frames()) { + // StackAudioFrames + // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py + cur = build_stack(cur, hparams.proj_stack_factor, n_embd); + cb(cur, "after_stacked", -1); + } + + if (proj_type == PROJECTOR_TYPE_ULTRAVOX) { + // UltravoxProjector + // pre-norm + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + + // ffn in + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + + // swiglu + // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half + cur = ggml_swiglu_swapped(ctx0, cur); + + // mid-norm + cur = ggml_rms_norm(ctx0, cur, 1e-6); + cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w); + + // ffn out + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + + } else if (proj_type == PROJECTOR_TYPE_QWEN2A) { + // projector + cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); + cur = ggml_add(ctx0, cur, model.mm_fc_b); + + } else if (proj_type == PROJECTOR_TYPE_VOXTRAL) { + // projector + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + + } else if (proj_type == PROJECTOR_TYPE_GLMA) { + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + cur = ggml_add(ctx0, cur, model.mm_norm_pre_b); + cur = build_stack(cur, hparams.proj_stack_factor, n_embd); + cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_2_w, model.mm_2_b, hparams.ffn_op, 0); + cur = ggml_concat(ctx0, model.mm_boi, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); + } else { + GGML_ABORT("%s: unknown projector type", __func__); + } + + cb(cur, "projected", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp b/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp index 84bdc277..2024d3d3 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp @@ -11,63 +11,149 @@ // most of the code here is copied from whisper.cpp -// align x to upper multiple of n -#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) +constexpr bool DEBUG = false; -namespace whisper_preprocessor { +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; -#define SIN_COS_N_COUNT WHISPER_N_FFT -namespace { -struct whisper_global_cache { - // In FFT, we frequently use sine and cosine operations with the same values. - // We can use precalculated values to speed up the process. 
- float sin_vals[SIN_COS_N_COUNT]; - float cos_vals[SIN_COS_N_COUNT]; + std::vector data; +}; - // Hann window (Use cosf to eliminate difference) - // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html - // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 - float hann_window[WHISPER_N_FFT]; +// note: this global cache is shared among all preprocessors +// if we want to use multiple preprocessors at the same time, +// we will need to enclose it in the preprocessor class in the future +static struct mtmd_audio_global_cache { + // precomputed sin/cos table for FFT + std::vector sin_vals; + std::vector cos_vals; - whisper_global_cache() { - fill_sin_cos_table(); - fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window); - } + // hann window + std::vector hann_window; - void fill_sin_cos_table() { - for (int i = 0; i < SIN_COS_N_COUNT; i++) { - double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + // mel filter bank + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; sin_vals[i] = sinf(theta); cos_vals[i] = cosf(theta); } } - void fill_hann_window(int length, bool periodic, float * output) { + void fill_hann_window(int length, bool periodic) { + hann_window.resize(length); int offset = -1; if (periodic) { offset = 0; } for (int i = 0; i < length; i++) { - output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); } } -} global_cache; -} + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix( + int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code + ) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; + } + + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? 
m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); + } + + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } + + const int n_fft_bins = n_fft / 2 + 1; + + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; + + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; + + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; + } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + } + } + + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); + + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); + } + } + } + } +} g_cache; // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -static void dft(const float* in, int N, float* out) { - const int sin_cos_step = SIN_COS_N_COUNT / N; +static void dft(const float * in, int N, float * out) { + const int n_sin_cos_vals = g_cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N - re += in[n]*global_cache.cos_vals[idx]; // cos(t) - im -= in[n]*global_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N + re += in[n] * g_cache.cos_vals[idx]; // cos(t) + im -= in[n] * g_cache.sin_vals[idx]; // sin(t) } out[k*2 + 0] = re; @@ -79,7 +165,8 @@ static void dft(const float* in, int N, float* out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -static void fft(float* in, int N, float* out) { +static void fft(float * in, int N, float * out) { + const int n_sin_cos_vals = g_cache.sin_vals.size(); if (N == 1) { out[0] = in[0]; out[1] = 0; @@ -106,11 +193,11 @@ static void fft(float* in, int N, float* out) { float* odd_fft = even_fft + N; fft(odd, half_N, odd_fft); - const int sin_cos_step = SIN_COS_N_COUNT / N; + const int sin_cos_step = n_sin_cos_vals / N; for (int k = 0; k < half_N; k++) { int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = global_cache.cos_vals[idx]; // cos(t) - float im = -global_cache.sin_vals[idx]; // sin(t) + float re = g_cache.cos_vals[idx]; // cos(t) + float im = -g_cache.sin_vals[idx]; // sin(t) float re_odd = odd_fft[2*k + 0]; float im_odd = odd_fft[2*k + 1]; @@ -123,20 +210,34 @@ static void fft(float* in, int N, float* out) { } } +struct filter_params { + int32_t n_mel; + int32_t n_fft_bins; + int32_t 
hann_window_size; + int32_t hop_length; + int32_t sample_rate; + bool center_padding = false; + float preemph = 0.f; + bool use_natural_log = false; + bool norm_per_feature = false; +}; + static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, int n_samples, int frame_size, int frame_step, int n_threads, - const whisper_filters & filters, whisper_mel & mel) { + const filter_params & params, mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); - int n_fft = filters.n_fft; + int n_fft_bins = params.n_fft_bins; int i = ith; - // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist - WHISPER_ASSERT(n_fft == 1 + (frame_size / 2)); + const auto & filters = g_cache.filters; + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); + GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; // apply Hann window (~10% faster) @@ -154,36 +255,39 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. - for (int j = 0; j < n_fft; j++) { + for (int j = 0; j < n_fft_bins; j++) { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { + for (int j = 0; j < out.n_mel; j++) { double sum = 0.0; // unroll loop (suggested by GH user @lunixbochs) int k = 0; - for (k = 0; k < n_fft - 3; k += 4) { + for (k = 0; k < n_fft_bins - 3; k += 4) { + size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k); sum += - fft_out[k + 0] * filters.data[j * n_fft + k + 0] + - fft_out[k + 1] * filters.data[j * n_fft + k + 1] + - fft_out[k + 2] * filters.data[j * n_fft + k + 2] + - fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + fft_out[k + 0] * filters.data[idx + 0] + + fft_out[k + 1] * filters.data[idx + 1] + + fft_out[k + 2] * filters.data[idx + 2] + + fft_out[k + 3] * filters.data[idx + 3]; } // handle n_fft remainder - for (; k < n_fft; k++) { - sum += fft_out[k] * filters.data[j * n_fft + k]; + for (; k < n_fft_bins; k++) { + sum += fft_out[k] * filters.data[j * n_fft_bins + k]; } - sum = log10(std::max(sum, 1e-10)); - mel.data[j * mel.n_len + i] = sum; + sum = params.use_natural_log + ? log(sum + 5.960464477539063e-08) + : log10(std::max(sum, 1e-10)); + out.data[j * out.n_len + i] = sum; } } // Otherwise fft_out are all zero - double sum = log10(1e-10); - for (; i < mel.n_len; i += n_threads) { - for (int j = 0; j < mel.n_mel; j++) { - mel.data[j * mel.n_len + i] = sum; + double sum = params.use_natural_log ? 
log(1e-10) : log10(1e-10); + for (; i < out.n_len; i += n_threads) { + for (int j = 0; j < out.n_mel; j++) { + out.data[j * out.n_len + i] = sum; } } } @@ -191,115 +295,212 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 static bool log_mel_spectrogram( const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int frame_size, - const int frame_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool debug, - whisper_mel & mel) { + const int n_samples_in, + const int n_threads, + const filter_params & params, + mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); + out.n_len_org = n_samples_in; + int n_samples = n_samples_in; + // Hann window - WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); - const float * hann = global_cache.hann_window; + const float * hann = g_cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; - // Calculate the length of padding - int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; - int64_t stage_2_pad = frame_size / 2; - - // Initialize a vector and copy data from C array to it. + // Padding std::vector samples_padded; - samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); - std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + if (params.center_padding) { + const auto pad_amount = frame_size / 2; + samples_padded = std::vector(n_samples + 2 * pad_amount, 0); + std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount); + samples = samples_padded.data(); + n_samples = samples_padded.size(); + } else { + // existing padding logic + int64_t stage_1_pad = params.sample_rate * 30; + int64_t stage_2_pad = frame_size / 2; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // reflective pad 200 samples at the beginning of audio + if (n_samples < stage_2_pad + 1) { + // TODO: Handle short audio differently or return error + return false; + } + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + } - // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio - std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + // preemphasis + if (params.preemph) { + const int pad_amount = frame_size / 2; + const float preemph = 0.97f; + float prev = samples_padded[pad_amount]; + for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { + float cur = samples_padded[i]; + samples_padded[i] = cur - preemph * prev; + prev = cur; + } + } - // reflective pad 200 samples at the beginning of audio - std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + // pad hann window if it's smaller than frame_size + // TODO: probably unnecessary here? (or better doing it in g_cache?) 
+ std::vector hann_window_padded; + if (params.hann_window_size < frame_size) { + hann_window_padded.resize(frame_size); + const int padding = (frame_size - params.hann_window_size) / 2; + std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]); + hann = hann_window_padded.data(); + } - mel.n_mel = n_mel; - // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 - // Calculate number of frames + remove the last frame - mel.n_len = (samples_padded.size() - frame_size) / frame_step; - // Calculate semi-padded sample length to ensure compatibility - mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; - mel.data.resize(mel.n_mel * mel.n_len); + + out.n_mel = params.n_mel; + out.n_len = (n_samples - frame_size) / frame_step + 1; + // TODO: handle these checks better + if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) { + LOG_ERR("%s: size overflow\n", __func__); + return false; + } + if (n_samples < frame_size) { + LOG_ERR("%s: not enough samples after padding\n", __func__); + return false; + } + out.data.resize(out.n_mel * out.n_len); { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples + stage_2_pad, frame_size, frame_step, n_threads, - std::cref(filters), std::ref(mel)); + n_samples, frame_size, frame_step, n_threads, + std::cref(params), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); - + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } } - // clamping and normalization - double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] > mmax) { - mmax = mel.data[i]; + const int effective_n_len = n_samples_in / frame_step; + if (params.norm_per_feature) { + for (int i = 0; i < out.n_mel; i++) { + double mean = 0; + for (int j = 0; j < effective_n_len; ++j) { + mean += out.data[i * out.n_len + j]; + } + mean /= effective_n_len; + + double var = 0.0; + for (int j = 0; j < effective_n_len; ++j) { + const double value = out.data[i * out.n_len + j] - mean; + var += value * value; + } + var /= effective_n_len - 1; // unbiased + const double mstd = std::sqrt(var + 1e-5); + + for (int j = 0; j < effective_n_len; ++j) { + auto &value = out.data[i * out.n_len + j]; + value = (value - mean) / mstd; + } + + // pad the rest with zeros + for (int j = effective_n_len; j < out.n_len; ++j) { + out.data[i * out.n_len + j] = 0.0; + } } - } - - mmax -= 8.0; - - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] < mmax) { - mel.data[i] = mmax; + } else { + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < out.n_mel*out.n_len; i++) { + if (out.data[i] > mmax) { + mmax = out.data[i]; + } } - mel.data[i] = (mel.data[i] + 4.0)/4.0; + mmax -= 8.0; + + for (int i = 0; i < out.n_mel*out.n_len; i++) { + if (out.data[i] < mmax) { + out.data[i] = mmax; + } + out.data[i] = (out.data[i] + 4.0)/4.0; + } } // Dump log_mel_spectrogram - if (debug) { + if (DEBUG) { std::ofstream outFile("log_mel_spectrogram.json"); outFile << "["; - for (uint64_t i = 0; i < mel.data.size() - 1; i++) { - outFile << mel.data[i] << ", "; + for (uint64_t i = 0; i < out.data.size() - 1; i++) { + 
outFile << out.data[i] << ", "; } - outFile << mel.data[mel.data.size() - 1] << "]"; + outFile << out.data[out.data.size() - 1] << "]"; outFile.close(); } return true; } -bool preprocess_audio( +// +// mtmd_audio_preprocessor_whisper +// + +void mtmd_audio_preprocessor_whisper::initialize() { + g_cache.fill_sin_cos_table(hparams.audio_n_fft); + g_cache.fill_hann_window(hparams.audio_window_len, true); + g_cache.fill_mel_filterbank_matrix( + hparams.n_mel_bins, + hparams.audio_n_fft, + hparams.audio_sample_rate); +} + +bool mtmd_audio_preprocessor_whisper::preprocess( const float * samples, size_t n_samples, - const whisper_filters & filters, - std::vector & output) { - + std::vector & output) { if (n_samples == 0) { // empty audio return false; } - whisper_mel out_full; + std::vector smpl; + // if input is too short, pad with zeros + // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram + // TODO: maybe handle this better + size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + if (n_samples < min_samples) { + smpl.resize(min_samples, 0.0f); + std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); + samples = smpl.data(); + n_samples = smpl.size(); + } + + filter_params params; + params.n_mel = hparams.n_mel_bins; + params.n_fft_bins = 1 + (hparams.audio_n_fft / 2); + params.hann_window_size = hparams.audio_window_len; + params.hop_length = hparams.audio_hop_len; + params.sample_rate = hparams.audio_sample_rate; + params.center_padding = false; + params.preemph = 0.0f; // disabled + params.use_natural_log = false; + params.norm_per_feature = false; + + // make sure the global cache is initialized + GGML_ASSERT(!g_cache.sin_vals.empty()); + GGML_ASSERT(!g_cache.cos_vals.empty()); + GGML_ASSERT(!g_cache.filters.data.empty()); + + mtmd_audio_mel out_full; bool ok = log_mel_spectrogram( samples, n_samples, - COMMON_SAMPLE_RATE, - WHISPER_N_FFT, - WHISPER_HOP_LENGTH, - filters.n_mel, 4, // n_threads - filters, - false, // debug + params, out_full); if (!ok) { return false; @@ -307,7 +508,9 @@ bool preprocess_audio( // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel // we always expect the mel to have 3000 silent frames at the end - // printf("n_len %d\n", out_full.n_len); + if (DEBUG) { + printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); + } const size_t frames_per_chunk = 3000; GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { @@ -316,7 +519,7 @@ bool preprocess_audio( break; // last uncomplete chunk will always be a padded chunk, safe to ignore } - whisper_mel out_chunk; + mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; out_chunk.n_len_org = out_full.n_mel; // unused @@ -332,438 +535,3 @@ bool preprocess_audio( return true; } - -} // namespace whisper_preprocessor - - -// precalculated mel filter banks -// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function -// -// generated from python code: -// -// from numpy import load -// data = load('mel_filters.npz') -// lst = data.files -// for item in lst: -// print(item) -// print(data[item].shape) -// n_mel = data[item].shape[0] -// n_fft = data[item].shape[1] -// for i, row in enumerate(data[item]): -// for j, val in enumerate(row): -// val = val * 1000.0 -// if val != 0: -// print(f"data[{i*n_fft + j}] = {val:.6f};") - 
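The namespace removed below held the filter-bank weights that the comment above describes (pre-generated from mel_filters.npz and stored multiplied by 1000); they are superseded by the runtime fill_mel_filterbank_matrix() added earlier in this file, which uses the Slaney mel scale. For reference, here is a standalone sketch of that hz/mel mapping with the same constants, handy for sanity-checking the generated weights against the old tables (the free functions below are illustrative; in the patch they exist as lambdas inside fill_mel_filterbank_matrix):

#include <cassert>
#include <cmath>
#include <cstdio>

// Slaney mel scale as used by fill_mel_filterbank_matrix (librosa's default):
// linear below 1 kHz, logarithmic above.
double hz_to_mel(double f_hz) {
    const double min_log_hz  = 1000.0;
    const double lin_slope   = 3.0 / 200.0;          // 15 mel per kHz in the linear region
    const double min_log_mel = min_log_hz * lin_slope;
    const double log_step    = std::log(6.4) / 27.0;
    return f_hz < min_log_hz ? f_hz * lin_slope
                             : min_log_mel + std::log(f_hz / min_log_hz) / log_step;
}

double mel_to_hz(double m) {
    const double min_log_hz  = 1000.0;
    const double lin_slope   = 3.0 / 200.0;
    const double min_log_mel = min_log_hz * lin_slope;
    const double log_step    = std::log(6.4) / 27.0;
    return m < min_log_mel ? m / lin_slope
                           : min_log_hz * std::exp((m - min_log_mel) * log_step);
}

int main() {
    // round-trip check over the 0..8 kHz range relevant to a 16 kHz preprocessor
    for (double f = 0.0; f <= 8000.0; f += 100.0) {
        assert(std::fabs(mel_to_hz(hz_to_mel(f)) - f) < 1e-6);
    }
    std::printf("8 kHz -> %.2f mel\n", hz_to_mel(8000.0));
    return 0;
}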
-namespace whisper_precalc_filters { - -whisper_preprocessor::whisper_filters get_128_bins() { - whisper_preprocessor::whisper_filters filters; - filters.n_mel = 128; - filters.n_fft = 201; - std::vector data(filters.n_mel * filters.n_fft, 0.0f); - - data[1] = 12.37398665; - data[202] = 30.39256483; - data[404] = 24.74797331; - data[605] = 18.01857911; - data[807] = 37.12195903; - data[1008] = 5.64459199; - data[1009] = 6.72939420; - data[1210] = 36.03715822; - data[1412] = 19.10337992; - data[1613] = 23.66316877; - data[1815] = 31.47736564; - data[2016] = 11.28918398; - data[2017] = 1.08480197; - data[2218] = 41.68175161; - data[2420] = 13.45878839; - data[2621] = 29.30776216; - data[2823] = 25.83277412; - data[3024] = 16.93377644; - data[3226] = 38.20675984; - data[3427] = 4.55979025; - data[3428] = 7.81419594; - data[3629] = 34.95235741; - data[3831] = 20.18818259; - data[4032] = 22.57836796; - data[4234] = 32.56217018; - data[4435] = 10.20438317; - data[4436] = 2.16960395; - data[4637] = 40.59694707; - data[4839] = 14.54358920; - data[5040] = 28.22295949; - data[5242] = 26.91757679; - data[5443] = 15.84897563; - data[5645] = 39.29156065; - data[5846] = 3.47498828; - data[5847] = 8.89899861; - data[6048] = 33.86755288; - data[6250] = 21.27298526; - data[6451] = 21.49356715; - data[6653] = 33.64697099; - data[6854] = 9.11958050; - data[6855] = 3.25440569; - data[7056] = 39.51214626; - data[7258] = 15.62839188; - data[7459] = 27.13815868; - data[7661] = 28.00237760; - data[7862] = 14.76417296; - data[8064] = 40.37636518; - data[8265] = 2.38068704; - data[8266] = 10.20263787; - data[8467] = 31.61146119; - data[8669] = 24.54700135; - data[8870] = 15.32919332; - data[8871] = 1.66583748; - data[9072] = 36.72905266; - data[9274] = 20.09709924; - data[9475] = 16.93102531; - data[9476] = 2.90265540; - data[9677] = 32.84499049; - data[9879] = 23.52004871; - data[10080] = 11.03894413; - data[10081] = 10.72582975; - data[10282] = 22.71829173; - data[10484] = 32.27872774; - data[10685] = 0.11626833; - data[10686] = 22.85348251; - data[10887] = 8.56344029; - data[10888] = 14.97978810; - data[11089] = 15.51398356; - data[11090] = 8.51490628; - data[11291] = 21.10680379; - data[11292] = 3.32652032; - data[11493] = 25.47064796; - data[11695] = 27.35907957; - data[11896] = 0.65853616; - data[11897] = 23.83812517; - data[12098] = 3.44359246; - data[12099] = 21.22455277; - data[12300] = 5.35842171; - data[12301] = 19.42555793; - data[12502] = 6.49324711; - data[12503] = 18.35542172; - data[12704] = 6.93138083; - data[12705] = 17.93504693; - data[12906] = 6.74968259; - data[12907] = 18.09151843; - data[13108] = 6.01899112; - data[13109] = 18.75767298; - data[13310] = 4.80452832; - data[13311] = 19.87172849; - data[13512] = 3.16627859; - data[13513] = 21.37690969; - data[13514] = 1.25317345; - data[13714] = 1.15934468; - data[13715] = 20.80361731; - data[13716] = 4.04486805; - data[13917] = 17.55363122; - data[13918] = 7.08320038; - data[14119] = 14.07538634; - data[14120] = 10.32655034; - data[14321] = 10.40921453; - data[14322] = 13.73696327; - data[14523] = 6.59187697; - data[14524] = 17.27988198; - data[14525] = 1.46804214; - data[14725] = 2.65681883; - data[14726] = 18.09193194; - data[14727] = 5.85655728; - data[14928] = 13.34277913; - data[14929] = 10.28267574; - data[15130] = 8.56800377; - data[15131] = 14.72230814; - data[15132] = 1.04039861; - data[15332] = 3.79085587; - data[15333] = 17.14678481; - data[15334] = 6.11609267; - data[15535] = 11.75929047; - data[15536] = 11.13393717; - data[15737] = 
6.43857848; - data[15738] = 16.07806236; - data[15739] = 4.23917221; - data[15939] = 1.19989377; - data[15940] = 12.75671553; - data[15941] = 9.65298992; - data[16142] = 7.06935255; - data[16143] = 14.94054683; - data[16144] = 4.19024844; - data[16344] = 1.51483389; - data[16345] = 12.00899947; - data[16346] = 9.84823331; - data[16547] = 6.10224018; - data[16548] = 15.33857174; - data[16549] = 5.57676842; - data[16749] = 0.36827257; - data[16750] = 9.89749376; - data[16751] = 11.35340426; - data[16752] = 2.05122307; - data[16952] = 3.89297144; - data[16953] = 12.97352277; - data[16954] = 8.06631614; - data[17155] = 6.74493238; - data[17156] = 13.85874674; - data[17157] = 5.41190524; - data[17357] = 0.74220158; - data[17358] = 8.98779090; - data[17359] = 11.37871388; - data[17360] = 3.32958088; - data[17560] = 2.82313535; - data[17561] = 10.68049297; - data[17562] = 9.43340641; - data[17563] = 1.76325557; - data[17763] = 4.39018616; - data[17764] = 11.87758986; - data[17765] = 7.97005836; - data[17766] = 0.66104700; - data[17966] = 5.49466675; - data[17967] = 12.62953598; - data[17968] = 6.93987962; - data[18169] = 6.18401915; - data[18170] = 12.93473132; - data[18171] = 6.29778765; - data[18371] = 0.02325210; - data[18372] = 6.50206627; - data[18373] = 12.32661773; - data[18374] = 6.00216538; - data[18574] = 0.31548753; - data[18575] = 6.48925547; - data[18576] = 12.04130240; - data[18577] = 6.01462880; - data[18777] = 0.29979556; - data[18778] = 6.18288014; - data[18779] = 12.04272825; - data[18780] = 6.29981188; - data[18781] = 0.55689598; - data[18980] = 0.01120471; - data[18981] = 5.61729167; - data[18982] = 11.22337859; - data[18983] = 6.82516303; - data[18984] = 1.35264499; - data[19184] = 4.82410006; - data[19185] = 10.16623247; - data[19186] = 7.56075513; - data[19187] = 2.34590308; - data[19387] = 3.83235747; - data[19388] = 8.92296247; - data[19389] = 8.47910438; - data[19390] = 3.50978645; - data[19590] = 2.66873185; - data[19591] = 7.51965167; - data[19592] = 9.55500547; - data[19593] = 4.81966138; - data[19594] = 0.08431751; - data[19793] = 1.35767367; - data[19794] = 5.98019501; - data[19795] = 10.60271543; - data[19796] = 6.25298498; - data[19797] = 1.74059917; - data[19997] = 4.32644226; - data[19998] = 8.73131864; - data[19999] = 7.78916525; - data[20000] = 3.48923868; - data[20200] = 2.57835095; - data[20201] = 6.77582854; - data[20202] = 9.40941647; - data[20203] = 5.31194592; - data[20204] = 1.21447595; - data[20403] = 0.75411191; - data[20404] = 4.75395704; - data[20405] = 8.75380263; - data[20406] = 7.19209015; - data[20407] = 3.28754401; - data[20607] = 2.68179690; - data[20608] = 6.49331464; - data[20609] = 9.11457930; - data[20610] = 5.39387390; - data[20611] = 1.67316827; - data[20810] = 0.57394296; - data[20811] = 4.20600036; - data[20812] = 7.83805829; - data[20813] = 7.52023002; - data[20814] = 3.97470826; - data[20815] = 0.42918732; - data[21014] = 1.90464477; - data[21015] = 5.36569161; - data[21016] = 8.82673822; - data[21017] = 6.27609482; - data[21018] = 2.89750961; - data[21218] = 2.89885257; - data[21219] = 6.19694078; - data[21220] = 8.56699049; - data[21221] = 5.34748193; - data[21222] = 2.12797290; - data[21421] = 0.44750227; - data[21422] = 3.59030394; - data[21423] = 6.73310598; - data[21424] = 7.77023612; - data[21425] = 4.70231380; - data[21426] = 1.63439126; - data[21625] = 1.01536023; - data[21626] = 4.01018746; - data[21627] = 7.00501446; - data[21628] = 7.23442994; - data[21629] = 4.31095669; - data[21630] = 1.38748321; - data[21829] = 
1.33348850; - data[21830] = 4.18730825; - data[21831] = 7.04112789; - data[21832] = 6.93188375; - data[21833] = 4.14605811; - data[21834] = 1.36023236; - data[22033] = 1.42879714; - data[22034] = 4.14824858; - data[22035] = 6.86769979; - data[22036] = 6.83705276; - data[22037] = 4.18239459; - data[22038] = 1.52773573; - data[22237] = 1.32610439; - data[22238] = 3.91751388; - data[22239] = 6.50892360; - data[22240] = 6.92639686; - data[22241] = 4.39672917; - data[22242] = 1.86706171; - data[22441] = 1.04827771; - data[22442] = 3.51767405; - data[22443] = 5.98707050; - data[22444] = 7.17824046; - data[22445] = 4.76767914; - data[22446] = 2.35711760; - data[22645] = 0.61636406; - data[22646] = 2.96949223; - data[22647] = 5.32262027; - data[22648] = 7.57265091; - data[22649] = 5.27558755; - data[22650] = 2.97852419; - data[22651] = 0.68146095; - data[22849] = 0.04971400; - data[22850] = 2.29204819; - data[22851] = 4.53438237; - data[22852] = 6.77671656; - data[22853] = 5.90240723; - data[22854] = 3.71349836; - data[22855] = 1.52458926; - data[23054] = 1.50285335; - data[23055] = 3.63961048; - data[23056] = 5.77636715; - data[23057] = 6.63159089; - data[23058] = 4.54574358; - data[23059] = 2.45989650; - data[23060] = 0.37404924; - data[23258] = 0.61795861; - data[23259] = 2.65410915; - data[23260] = 4.69025923; - data[23261] = 6.72641024; - data[23262] = 5.46034705; - data[23263] = 3.47270933; - data[23264] = 1.48507138; - data[23463] = 1.59233576; - data[23464] = 3.53261665; - data[23465] = 5.47289755; - data[23466] = 6.44368259; - data[23467] = 4.54962999; - data[23468] = 2.65557761; - data[23469] = 0.76152512; - data[23667] = 0.46749352; - data[23668] = 2.31641904; - data[23669] = 4.16534441; - data[23670] = 6.01426978; - data[23671] = 5.67844696; - data[23672] = 3.87357362; - data[23673] = 2.06870004; - data[23674] = 0.26382666; - data[23872] = 1.05349103; - data[23873] = 2.81536230; - data[23874] = 4.57723346; - data[23875] = 6.33910485; - data[23876] = 5.12815686; - data[23877] = 3.40826320; - data[23878] = 1.68837002; - data[24077] = 1.43350090; - data[24078] = 3.11241671; - data[24079] = 4.79133241; - data[24080] = 6.40943693; - data[24081] = 4.77052201; - data[24082] = 3.13160778; - data[24083] = 1.49269309; - data[24281] = 0.02932359; - data[24282] = 1.62918994; - data[24283] = 3.22905602; - data[24284] = 4.82892245; - data[24285] = 6.14671456; - data[24286] = 4.58496623; - data[24287] = 3.02321767; - data[24288] = 1.46146910; - data[24486] = 0.13601698; - data[24487] = 1.66055572; - data[24488] = 3.18509457; - data[24489] = 4.70963307; - data[24490] = 6.04072399; - data[24491] = 4.55250870; - data[24492] = 3.06429295; - data[24493] = 1.57607743; - data[24494] = 0.08786193; - data[24691] = 0.09328097; - data[24692] = 1.54603878; - data[24693] = 2.99879676; - data[24694] = 4.45155473; - data[24695] = 5.90431225; - data[24696] = 4.65566106; - data[24697] = 3.23751615; - data[24698] = 1.81937125; - data[24699] = 0.40122634; - data[24897] = 1.30262633; - data[24898] = 2.68698297; - data[24899] = 4.07133950; - data[24900] = 5.45569602; - data[24901] = 4.87832492; - data[24902] = 3.52695142; - data[24903] = 2.17557792; - data[24904] = 0.82420459; - data[25102] = 0.94595028; - data[25103] = 2.26512621; - data[25104] = 3.58430226; - data[25105] = 4.90347855; - data[25106] = 5.20569785; - data[25107] = 3.91795207; - data[25108] = 2.63020652; - data[25109] = 1.34246063; - data[25110] = 0.05471494; - data[25307] = 0.49037894; - data[25308] = 1.74744334; - data[25309] = 3.00450763; - data[25310] 
= 4.26157191; - data[25311] = 5.51863620; - data[25312] = 4.39707236; - data[25313] = 3.16995848; - data[25314] = 1.94284460; - data[25315] = 0.71573065; - data[25513] = 1.14698056; - data[25514] = 2.34485767; - data[25515] = 3.54273478; - data[25516] = 4.74061165; - data[25517] = 4.95198462; - data[25518] = 3.78264743; - data[25519] = 2.61331047; - data[25520] = 1.44397374; - data[25521] = 0.27463681; - data[25718] = 0.47569509; - data[25719] = 1.61717169; - data[25720] = 2.75864848; - data[25721] = 3.90012516; - data[25722] = 5.04160160; - data[25723] = 4.45712078; - data[25724] = 3.34284059; - data[25725] = 2.22856039; - data[25726] = 1.11428020; - - for (auto & val : data) { - val /= 1000.0f; - } - - filters.data = std::move(data); - return filters; -} - -} // namespace whisper_precalc_filters diff --git a/llama/llama.cpp/tools/mtmd/mtmd-audio.h b/llama/llama.cpp/tools/mtmd/mtmd-audio.h index b7b940af..1b454337 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd-audio.h +++ b/llama/llama.cpp/tools/mtmd/mtmd-audio.h @@ -1,23 +1,15 @@ #pragma once #include "ggml.h" +#include "clip-model.h" #include #include #include -#define WHISPER_ASSERT GGML_ASSERT +#define MTMD_INTERNAL_HEADER -#define WHISPER_SAMPLE_RATE 16000 -#define WHISPER_N_FFT 400 -#define WHISPER_HOP_LENGTH 160 -#define WHISPER_CHUNK_SIZE 30 - -#define COMMON_SAMPLE_RATE 16000 - -namespace whisper_preprocessor { - -struct whisper_mel { +struct mtmd_audio_mel { int n_len; int n_len_org; int n_mel; @@ -25,23 +17,18 @@ struct whisper_mel { std::vector data; }; -struct whisper_filters { - int32_t n_mel; - int32_t n_fft; +struct mtmd_audio_preprocessor { + const clip_hparams & hparams; - std::vector data; + mtmd_audio_preprocessor(const clip_ctx * ctx): hparams(*clip_get_hparams(ctx)) {} + + virtual ~mtmd_audio_preprocessor() = default; + virtual void initialize() = 0; // NOT thread-safe + virtual bool preprocess(const float * samples, size_t n_samples, std::vector & output) = 0; }; -bool preprocess_audio( - const float * samples, - size_t n_samples, - const whisper_filters & filters, - std::vector & output); - -} // namespace whisper_preprocessor - -namespace whisper_precalc_filters { - -whisper_preprocessor::whisper_filters get_128_bins(); - -} // namespace whisper_precalc_filters +struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { + mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} + void initialize() override; + bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; +}; diff --git a/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp b/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp index f0891bba..902a4b45 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp @@ -32,6 +32,10 @@ #define STB_IMAGE_IMPLEMENTATION #include "stb/stb_image.h" +#ifdef MTMD_INTERNAL_HEADER +#error "mtmd-helper is a public library outside of mtmd. 
it must not include internal headers" +#endif + // // internal logging functions // diff --git a/llama/llama.cpp/tools/mtmd/mtmd.cpp b/llama/llama.cpp/tools/mtmd/mtmd.cpp index 0f5712e2..c4e905a4 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd.cpp @@ -161,8 +161,7 @@ struct mtmd_context { // string template for slice image delimiters with row/col (idefics3) std::string sli_img_start_tmpl; - // for whisper, we pre-calculate the mel filter bank - whisper_preprocessor::whisper_filters w_filters; + std::unique_ptr audio_preproc; // TODO @ngxson : add timings @@ -228,7 +227,7 @@ struct mtmd_context { void init_vision() { GGML_ASSERT(ctx_v != nullptr); - use_mrope = clip_is_qwen2vl(ctx_v); + use_mrope = clip_is_mrope(ctx_v); projector_type proj = clip_get_projector_type(ctx_v); int minicpmv_version = clip_is_minicpmv(ctx_v); @@ -320,6 +319,10 @@ struct mtmd_context { img_beg = "<|image_start|>"; img_end = "<|image_end|>"; + } else if (proj == PROJECTOR_TYPE_GLM4V) { + img_beg = "<|begin_of_image|>"; + img_end = "<|end_of_image|>"; + } } @@ -327,14 +330,25 @@ struct mtmd_context { GGML_ASSERT(ctx_a != nullptr); projector_type proj = clip_get_projector_type(ctx_a); - if (clip_has_whisper_encoder(ctx_a)) { - // TODO @ngxson : check if model n_mel is 128 or 80 - w_filters = whisper_precalc_filters::get_128_bins(); - } - LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n" " https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__); + // set preprocessor + switch (proj) { + case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_QWEN25O: + case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_VOXTRAL: + audio_preproc = std::make_unique(ctx_a); + break; + default: + GGML_ABORT("unsupported audio projector type"); + } + + // initialize audio preprocessor + audio_preproc->initialize(); + + // set special tokens if (proj == PROJECTOR_TYPE_QWEN2A) { // <|audio_bos|> ... (embeddings) ... <|audio_eos|> aud_beg = "<|audio_bos|>"; @@ -663,11 +677,10 @@ struct mtmd_tokenizer { } // preprocess audio - GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded - std::vector mel_spec_chunks; + std::vector mel_spec_chunks; const float * samples = (const float *)bitmap->data.data(); size_t n_samples = bitmap->data.size() / sizeof(float); - bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks); + bool ok = ctx->audio_preproc->preprocess(samples, n_samples, mel_spec_chunks); if (!ok) { LOG_ERR("Unable to preprocess audio\n"); return 2; @@ -873,8 +886,7 @@ int mtmd_get_audio_bitrate(mtmd_context * ctx) { if (!ctx->ctx_a) { return -1; } - // for now, we assume that all audio models have the same bitrate - return 16000; // 16kHz + return clip_get_hparams(ctx->ctx_a)->audio_sample_rate; } // diff --git a/llama/llama.cpp/tools/mtmd/mtmd.h b/llama/llama.cpp/tools/mtmd/mtmd.h index a6a1af3b..72cec193 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.h +++ b/llama/llama.cpp/tools/mtmd/mtmd.h @@ -22,6 +22,11 @@ * Issues related to API usage may receive lower priority support. 
* * For the usage, see an example in mtmd-cli.cpp + * + * For contributors: + * - Make sure the C API is aligned with the libllama C API (as in llama.h) + * - Do not include model name (e.g., qwen, gemma) in the API, use generic terms instead + * - Keep the API minimal, do not expose internal details unless necessary */ #ifdef LLAMA_SHARED diff --git a/llama/llama.go b/llama/llama.go index 49b3f56a..87844f2a 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -42,6 +42,7 @@ import ( _ "github.com/ollama/ollama/llama/llama.cpp/common" _ "github.com/ollama/ollama/llama/llama.cpp/src" _ "github.com/ollama/ollama/llama/llama.cpp/tools/mtmd" + _ "github.com/ollama/ollama/llama/llama.cpp/tools/mtmd/models" "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch index 7a91351e..126dee34 100644 --- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch +++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch @@ -23,10 +23,10 @@ problem. 8 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 08681f35e..afde2f0b7 100644 +index 8547ecc84..9f37ca70c 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -113,7 +113,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { +@@ -112,7 +112,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { if (buffer->iface.free_buffer != NULL) { buffer->iface.free_buffer(buffer); } @@ -34,7 +34,7 @@ index 08681f35e..afde2f0b7 100644 } size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { -@@ -586,6 +585,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) +@@ -591,6 +590,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) free(ctx->buffers); free(ctx); @@ -42,7 +42,7 @@ index 08681f35e..afde2f0b7 100644 } static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { -@@ -2106,6 +2106,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { +@@ -2125,6 +2125,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); ggml_aligned_free(buffer->context, buffer->size); @@ -54,7 +54,7 @@ index 08681f35e..afde2f0b7 100644 } static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { -@@ -2158,7 +2163,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { +@@ -2177,7 +2182,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { }; static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { @@ -64,7 +64,7 @@ index 08681f35e..afde2f0b7 100644 /* .init_tensor = */ NULL, // no initialization required /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp -index 81288464c..866758782 100644 +index da624c587..efc63e092 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -831,6 +831,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { @@ -84,7 +84,7 @@ index 
81288464c..866758782 100644 /** diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 279679a4e..5145c1e88 100644 +index ab0f6fe9c..6519af435 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -583,6 +583,7 @@ struct ggml_backend_cuda_buffer_context { @@ -156,10 +156,10 @@ index 18a45d2d9..89041805e 100644 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp -index 7449a9160..e69a1ff5f 100644 +index e996d98be..84b679315 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp -@@ -355,6 +355,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { +@@ -356,6 +356,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { ggml_sycl_set_device(ctx->device); delete ctx; @@ -167,7 +167,7 @@ index 7449a9160..e69a1ff5f 100644 } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ -@@ -816,6 +817,7 @@ struct ggml_backend_sycl_split_buffer_context { +@@ -817,6 +818,7 @@ struct ggml_backend_sycl_split_buffer_context { static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; delete ctx; @@ -175,7 +175,7 @@ index 7449a9160..e69a1ff5f 100644 } static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -1158,6 +1160,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ +@@ -1159,6 +1161,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_sycl_host_free(buffer->context); @@ -184,10 +184,10 @@ index 7449a9160..e69a1ff5f 100644 static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index c6f5809cc..c801d2fd2 100644 +index 34ec09d40..120191ca0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -12271,6 +12271,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -12365,6 +12365,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; @@ -195,7 +195,7 @@ index c6f5809cc..c801d2fd2 100644 } static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -12414,6 +12415,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe +@@ -12508,6 +12509,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch index 7bb5f48a..9cee5c56 100644 --- a/llama/patches/0002-pretokenizer.patch +++ b/llama/patches/0002-pretokenizer.patch @@ -10,7 +10,7 @@ logs instead of throwing an error 1 file changed, 3 insertions(+), 11 deletions(-) diff --git 
a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index e2cca66e4..8246a0a14 100644 +index 7b01a2edf..63250cdf1 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1825,16 +1825,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { @@ -31,7 +31,7 @@ index e2cca66e4..8246a0a14 100644 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || -@@ -2014,7 +2005,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -2015,7 +2006,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; clean_spaces = false; } else { diff --git a/llama/patches/0003-clip-unicode.patch b/llama/patches/0003-clip-unicode.patch index d05b01d9..73d10732 100644 --- a/llama/patches/0003-clip-unicode.patch +++ b/llama/patches/0003-clip-unicode.patch @@ -10,7 +10,7 @@ filesystems for paths that include wide characters 1 file changed, 39 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 3ed08a0fe..6be1470ad 100644 +index 35e3aef0a..84a3796b5 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -24,6 +24,19 @@ @@ -32,8 +32,8 @@ index 3ed08a0fe..6be1470ad 100644 + struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL}; - enum ffn_op_type { -@@ -3257,7 +3270,29 @@ struct clip_model_loader { + //#define CLIP_DEBUG_FUNCTIONS +@@ -1619,7 +1632,29 @@ struct clip_model_loader { { std::vector read_buf; @@ -63,7 +63,7 @@ index 3ed08a0fe..6be1470ad 100644 if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } -@@ -3284,7 +3319,11 @@ struct clip_model_loader { +@@ -1646,7 +1681,11 @@ struct clip_model_loader { ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } } diff --git a/llama/patches/0004-solar-pro.patch b/llama/patches/0004-solar-pro.patch index 7adce420..f267356e 100644 --- a/llama/patches/0004-solar-pro.patch +++ b/llama/patches/0004-solar-pro.patch @@ -6,7 +6,7 @@ Subject: [PATCH] solar-pro adds support for the Solar Pro architecture --- src/CMakeLists.txt | 1 + - src/llama-arch.cpp | 21 +++++ + src/llama-arch.cpp | 20 +++++ src/llama-arch.h | 3 + src/llama-hparams.cpp | 8 ++ src/llama-hparams.h | 5 ++ @@ -15,7 +15,7 @@ adds support for the Solar Pro architecture src/llama-model.h | 3 + src/models/models.h | 5 ++ src/models/solar.cpp | 158 +++++++++++++++++++++++++++++++++++++ - 10 files changed, 253 insertions(+), 1 deletion(-) + 10 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 src/models/solar.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt @@ -31,10 +31,10 @@ index 4192af7c0..bd44d73e7 100644 models/starcoder.cpp models/starcoder2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index 64ad1b776..a5fe4f66c 100644 +index 8caf80afc..2ce8ffec0 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp -@@ -85,6 +85,7 @@ static const std::map LLM_ARCH_NAMES = { +@@ -87,6 +87,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -42,7 +42,7 @@ index 64ad1b776..a5fe4f66c 100644 { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, -@@ -206,6 +207,7 @@ static const std::map LLM_KV_NAMES = { +@@ -208,6 +209,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, 
"%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, @@ -50,32 +50,38 @@ index 64ad1b776..a5fe4f66c 100644 { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, -@@ -2025,6 +2027,24 @@ static const std::map> LLM_TENSOR_N - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, - }, - }, -+ { -+ LLM_ARCH_SOLAR, -+ { -+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, -+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, -+ { LLM_TENSOR_OUTPUT, "output" }, -+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, -+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, -+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, -+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, -+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, -+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, -+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, -+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, -+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, -+ { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, -+ }, -+ }, - { - LLM_ARCH_WAVTOKENIZER_DEC, - { -@@ -2710,6 +2730,7 @@ static const std::map LLM_TENSOR_INFOS = { +@@ -339,6 +341,7 @@ static const std::map LLM_TENSOR_NAMES = { + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, ++ { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, +@@ -2176,6 +2179,22 @@ static std::set llm_get_tensor_names(llm_arch arch) { + return { + LLM_TENSOR_TOKEN_EMBD, + }; ++ case LLM_ARCH_SOLAR: ++ return { ++ LLM_TENSOR_TOKEN_EMBD, ++ LLM_TENSOR_OUTPUT_NORM, ++ LLM_TENSOR_OUTPUT, ++ LLM_TENSOR_ATTN_NORM, ++ LLM_TENSOR_ATTN_Q, ++ LLM_TENSOR_ATTN_K, ++ LLM_TENSOR_ATTN_V, ++ LLM_TENSOR_ATTN_OUT, ++ LLM_TENSOR_FFN_NORM, ++ LLM_TENSOR_FFN_GATE, ++ LLM_TENSOR_FFN_DOWN, ++ LLM_TENSOR_FFN_UP, ++ LLM_TENSOR_BSKCN_TV, ++ }; + default: + GGML_ABORT("unknown architecture for tensor mapping"); + } +@@ -2344,6 +2363,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, @@ -84,10 +90,10 @@ index 64ad1b776..a5fe4f66c 100644 {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h -index e11318002..ec9e3a6df 100644 +index 6cbf9b1f8..14d461c76 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h -@@ -89,6 +89,7 @@ enum llm_arch { +@@ -91,6 +91,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_HYBRID, LLM_ARCH_CHAMELEON, @@ -95,7 +101,7 @@ index e11318002..ec9e3a6df 100644 LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, -@@ -210,6 +211,7 @@ enum llm_kv { +@@ -212,6 +213,7 @@ enum llm_kv { LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_SCALE, @@ -103,7 +109,7 @@ index e11318002..ec9e3a6df 100644 LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, -@@ -462,6 +464,7 @@ enum llm_tensor { +@@ -465,6 +467,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, @@ -112,10 +118,10 @@ index e11318002..ec9e3a6df 100644 LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git 
a/src/llama-hparams.cpp b/src/llama-hparams.cpp -index 8cdbaf69f..41127bf91 100644 +index fe1fa4341..aabff2f06 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp -@@ -161,6 +161,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { +@@ -163,6 +163,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1; } @@ -131,10 +137,10 @@ index 8cdbaf69f..41127bf91 100644 if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h -index 6eff334a5..a778fc3cf 100644 +index f6e95b5d2..c6e673276 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h -@@ -64,6 +64,8 @@ struct llama_hparams { +@@ -65,6 +65,8 @@ struct llama_hparams { std::array n_head_kv_arr; std::array n_ff_arr; @@ -143,7 +149,7 @@ index 6eff334a5..a778fc3cf 100644 uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; -@@ -256,6 +258,9 @@ struct llama_hparams { +@@ -259,6 +261,9 @@ struct llama_hparams { uint32_t n_pos_per_embd() const; @@ -154,7 +160,7 @@ index 6eff334a5..a778fc3cf 100644 bool has_kv(uint32_t il) const; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp -index aa3a65f87..ee303bd58 100644 +index ca2ea2461..8916a6242 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -466,7 +466,7 @@ namespace GGUFMeta { @@ -167,10 +173,10 @@ index aa3a65f87..ee303bd58 100644 llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index 04fccc979..3c503b424 100644 +index ae8207ee1..00cd579e0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp -@@ -1975,6 +1975,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { +@@ -1995,6 +1995,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; @@ -192,7 +198,7 @@ index 04fccc979..3c503b424 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -5401,6 +5416,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -5429,6 +5444,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -227,7 +233,7 @@ index 04fccc979..3c503b424 100644 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); -@@ -7480,6 +7523,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { +@@ -7534,6 +7577,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; @@ -238,7 +244,7 @@ index 04fccc979..3c503b424 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { llm = std::make_unique(*this, params); -@@ -7743,6 +7790,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { +@@ -7798,6 +7845,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: @@ -247,7 +253,7 @@ index 04fccc979..3c503b424 100644 case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: diff --git a/src/llama-model.h b/src/llama-model.h -index f8342cf2c..cbf4e1bfa 100644 +index c6eb95318..b378b23ec 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -76,6 
+76,7 @@ enum llm_type { @@ -258,7 +264,7 @@ index f8342cf2c..cbf4e1bfa 100644 LLM_TYPE_26B, LLM_TYPE_27B, LLM_TYPE_30B, -@@ -404,6 +405,8 @@ struct llama_layer { +@@ -405,6 +406,8 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; @@ -268,7 +274,7 @@ index f8342cf2c..cbf4e1bfa 100644 struct llama_layer_convnext convnext; diff --git a/src/models/models.h b/src/models/models.h -index 6494f5450..e0aec822c 100644 +index ffb36acc6..6d84a185d 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -515,6 +515,11 @@ struct llm_build_smollm3 : public llm_graph_context { diff --git a/llama/patches/0005-fix-deepseek-deseret-regex.patch b/llama/patches/0005-fix-deepseek-deseret-regex.patch index 22f3cd9f..9aa2ae46 100644 --- a/llama/patches/0005-fix-deepseek-deseret-regex.patch +++ b/llama/patches/0005-fix-deepseek-deseret-regex.patch @@ -12,7 +12,7 @@ regex 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 8246a0a14..dfba7778b 100644 +index 63250cdf1..dd86a1745 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { diff --git a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch index 83061168..80ff0592 100644 --- a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch +++ b/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch @@ -8,10 +8,10 @@ Subject: [PATCH] maintain ordering for rules for grammar 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp -index c3b4e5d9d..6be552826 100644 +index 2f67c74d7..acf00e2d2 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp -@@ -310,7 +310,7 @@ private: +@@ -311,7 +311,7 @@ private: friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; diff --git a/llama/patches/0010-fix-string-arr-kv-loading.patch b/llama/patches/0010-fix-string-arr-kv-loading.patch index 622783d9..63acee83 100644 --- a/llama/patches/0010-fix-string-arr-kv-loading.patch +++ b/llama/patches/0010-fix-string-arr-kv-loading.patch @@ -53,7 +53,7 @@ index b165d8bdc..f91d4faba 100644 } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index dfba7778b..f72f321b9 100644 +index dd86a1745..d63ce9c84 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1781,9 +1781,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { diff --git a/llama/patches/0011-ollama-debug-tensor.patch b/llama/patches/0011-ollama-debug-tensor.patch index 8680c91d..a2a4eb6b 100644 --- a/llama/patches/0011-ollama-debug-tensor.patch +++ b/llama/patches/0011-ollama-debug-tensor.patch @@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index b468b115a..bb65985b4 100644 +index a59b51893..53891a91f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -15,6 +15,8 @@ @@ -20,7 +20,7 @@ index b468b115a..bb65985b4 100644 #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -@@ -2928,6 +2930,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { +@@ -2945,6 
+2947,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0014-graph-memory-reporting-on-failure.patch index aa466862..0b818ec8 100644 --- a/llama/patches/0014-graph-memory-reporting-on-failure.patch +++ b/llama/patches/0014-graph-memory-reporting-on-failure.patch @@ -11,10 +11,10 @@ Subject: [PATCH] graph memory reporting on failure 4 files changed, 40 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h -index 2cb150fd2..7ab3f0192 100644 +index 78aa059dd..7fa8403b3 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h -@@ -65,6 +65,7 @@ GGML_API bool ggml_gallocr_reserve_n( +@@ -72,6 +72,7 @@ GGML_API bool ggml_gallocr_reserve_n( GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); @@ -23,10 +23,10 @@ index 2cb150fd2..7ab3f0192 100644 // Utils // Create a buffer and allocate all the tensors in a ggml_context diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index f1b740785..c54ff98bf 100644 +index 4ed5f3577..a7ebe5dcd 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -318,6 +318,7 @@ extern "C" { +@@ -319,6 +319,7 @@ extern "C" { GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend); GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); @@ -35,10 +35,10 @@ index f1b740785..c54ff98bf 100644 GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c -index a5995fdc2..dbfd8b5b2 100644 +index 41419b617..73b39bfea 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c -@@ -494,6 +494,7 @@ struct node_alloc { +@@ -485,6 +485,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] struct vbuffer ** buffers; // [n_buffers] @@ -46,7 +46,7 @@ index a5995fdc2..dbfd8b5b2 100644 struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; -@@ -517,6 +518,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs +@@ -508,6 +509,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *)); GGML_ASSERT(galloc->buffers != NULL); @@ -56,7 +56,7 @@ index a5995fdc2..dbfd8b5b2 100644 galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); -@@ -584,6 +588,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { +@@ -575,6 +579,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); @@ -64,7 +64,7 @@ index a5995fdc2..dbfd8b5b2 100644 free(galloc->buffers); free(galloc->buf_tallocs); free(galloc->node_allocs); -@@ -899,6 +904,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -904,6 +909,8 @@ static bool ggml_gallocr_reserve_n_impl( } } @@ -73,18 +73,19 @@ index a5995fdc2..dbfd8b5b2 100644 // reallocate buffers if needed for (int i = 0; i < 
galloc->n_buffers; i++) { // if the buffer type is used multiple times, we reuse the same buffer -@@ -933,14 +940,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c - #endif - ggml_vbuffer_free(galloc->buffers[i]); - galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); -- if (galloc->buffers[i] == NULL) { -+ if (galloc->buffers[i]) { -+ galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); -+ } else { - GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); -- return false; -+ galloc->buffer_sizes[i] = new_size; -+ success = false; +@@ -940,15 +947,20 @@ static bool ggml_gallocr_reserve_n_impl( + galloc->buffers[i] = NULL; + } else { + galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); +- if (galloc->buffers[i] == NULL) { ++ if (galloc->buffers[i]) { ++ galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); ++ } else { + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); +- return false; ++ galloc->buffer_sizes[i] = new_size; ++ success = false; + } } + } else { + galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); @@ -95,8 +96,8 @@ index a5995fdc2..dbfd8b5b2 100644 + return success; } - bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { -@@ -1095,6 +1107,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { + void ggml_gallocr_reserve_n_size( +@@ -1118,6 +1130,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_vbuffer_size(galloc->buffers[buffer_id]); } @@ -120,10 +121,10 @@ index a5995fdc2..dbfd8b5b2 100644 static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index afde2f0b7..dbf8486a0 100644 +index 9f37ca70c..1459d16dd 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -1840,6 +1840,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe +@@ -1859,6 +1859,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe return ggml_gallocr_get_buffer_size(sched->galloc, backend_index); } diff --git a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch index 1ae032ab..ec0dfdc6 100644 --- a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch +++ b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch @@ -10,7 +10,7 @@ Subject: [PATCH] ggml: Export GPU UUIDs 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index c54ff98bf..229bf387b 100644 +index a7ebe5dcd..03557bb31 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -158,6 +158,7 @@ extern "C" { @@ -22,7 +22,7 @@ index c54ff98bf..229bf387b 100644 size_t memory_total; // device type diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 5145c1e88..f641c1016 100644 +index 6519af435..c9d3a2b03 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -189,6 +189,51 @@ static int ggml_cuda_parse_id(char devName[]) { @@ -136,7 +136,7 @@ index 5145c1e88..f641c1016 100644 props->type = ggml_backend_cuda_device_get_type(dev); props->device_id 
= ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); -@@ -4833,6 +4887,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4834,6 +4888,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; diff --git a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch index 19c4d25d..8205e2cb 100644 --- a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch +++ b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch @@ -10,7 +10,7 @@ Signed-off-by: Gabe Goodhart 2 files changed, 13 insertions(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp -index d06fa42e6..0f5712e21 100644 +index 2638fe4fc..c4e905a4e 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -87,6 +87,16 @@ enum mtmd_slice_tmpl { @@ -31,10 +31,10 @@ index d06fa42e6..0f5712e21 100644 return "<__media__>"; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h -index b3df24c29..a6a1af3b8 100644 +index 9f7e861e9..72cec1937 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h -@@ -75,6 +75,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk; +@@ -80,6 +80,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk; typedef struct mtmd_input_chunks mtmd_input_chunks; typedef struct mtmd_input_text mtmd_input_text; diff --git a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch index a788a562..010d609e 100644 --- a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch +++ b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch @@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index bb65985b4..47089a62e 100644 +index 53891a91f..8d4851312 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c -@@ -2464,7 +2464,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { +@@ -2479,7 +2479,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { // Newer Windows 11 versions aggresively park (offline) CPU cores and often place // all our threads onto the first 4 cores which results in terrible performance with // n_threads > 4 diff --git a/llama/patches/0018-ggml-Add-batch-size-hint.patch b/llama/patches/0018-ggml-Add-batch-size-hint.patch index cef00be5..5b66ee36 100644 --- a/llama/patches/0018-ggml-Add-batch-size-hint.patch +++ b/llama/patches/0018-ggml-Add-batch-size-hint.patch @@ -20,7 +20,7 @@ consistent performance. 
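Editor's note — a quick, hedged sketch of how a caller might consume the batch-size hint this patch introduces (illustration only, not part of the patch): `ggml_backend_sched_set_batch_size()` is the function the patch declares in ggml-backend.h, `ggml_backend_sched_graph_compute()` is the existing ggml scheduler entry point, and the wrapper itself is a hypothetical helper.

```cpp
#include "ggml-backend.h"

// Hypothetical wrapper: forward the per-call batch size to the scheduler before
// computing, so backends that tune kernel/stream selection on batch size see the
// same hint on every run with the same inputs.
static enum ggml_status compute_with_batch_hint(ggml_backend_sched_t sched,
                                                struct ggml_cgraph * graph,
                                                int n_tokens) {
    ggml_backend_sched_set_batch_size(sched, n_tokens);    // hint added by this patch
    return ggml_backend_sched_graph_compute(sched, graph); // existing ggml API
}
```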
8 files changed, 58 insertions(+), 32 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 229bf387b..2763f2bd6 100644 +index 03557bb31..93c95602d 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -98,7 +98,7 @@ extern "C" { @@ -40,8 +40,8 @@ index 229bf387b..2763f2bd6 100644 + GGML_API void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size); + // Initialize backend buffers from a measure graph + GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes); GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success - diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 6792ba986..0f5b03cef 100644 --- a/ggml/src/ggml-backend-impl.h @@ -58,10 +58,10 @@ index 6792ba986..0f5b03cef 100644 // (optional) event synchronization // record an event on this stream diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index dbf8486a0..312ca873c 100644 +index 1459d16dd..498186a7c 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -348,14 +348,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba +@@ -353,14 +353,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba } enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { @@ -79,7 +79,7 @@ index dbf8486a0..312ca873c 100644 } bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { -@@ -722,6 +722,8 @@ struct ggml_backend_sched { +@@ -727,6 +727,8 @@ struct ggml_backend_sched { bool op_offload; @@ -88,7 +88,7 @@ index dbf8486a0..312ca873c 100644 int debug; // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC] -@@ -820,7 +822,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st +@@ -825,7 +827,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op @@ -97,7 +97,7 @@ index dbf8486a0..312ca873c 100644 for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); -@@ -1572,7 +1574,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s +@@ -1577,7 +1579,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } if (!sched->callback_eval) { @@ -106,7 +106,7 @@ index dbf8486a0..312ca873c 100644 if (ec != GGML_STATUS_SUCCESS) { return ec; } -@@ -1594,7 +1596,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s +@@ -1599,7 +1601,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); @@ -115,7 +115,7 @@ index dbf8486a0..312ca873c 100644 if (ec != GGML_STATUS_SUCCESS) { return ec; } -@@ -1684,6 +1686,7 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1689,6 +1691,7 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); sched->op_offload = 
op_offload; @@ -123,7 +123,7 @@ index dbf8486a0..312ca873c 100644 ggml_backend_sched_reset(sched); -@@ -1715,6 +1718,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { +@@ -1720,6 +1723,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { free(sched); } @@ -156,7 +156,7 @@ index 5b888cdd8..88d088952 100644 static struct ggml_backend_i blas_backend_i = { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp -index 3191faaa4..32f14c811 100644 +index f4713a421..92ba577a5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -164,7 +164,7 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe @@ -178,7 +178,7 @@ index 3191faaa4..32f14c811 100644 static const struct ggml_backend_i ggml_backend_cpu_i = { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index f641c1016..17062697b 100644 +index c9d3a2b03..25548629d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2901,7 +2901,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { @@ -278,10 +278,10 @@ index 8fc1c2fb5..ba95b4acc 100644 static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index c801d2fd2..b2c0d0cee 100644 +index 120191ca0..5349bce24 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -13006,7 +13006,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru +@@ -13099,7 +13099,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } @@ -290,7 +290,7 @@ index c801d2fd2..b2c0d0cee 100644 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -@@ -13241,6 +13241,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg +@@ -13334,6 +13334,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg return GGML_STATUS_SUCCESS; UNUSED(backend); diff --git a/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch b/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch index 761c18fc..2c4e3050 100644 --- a/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch +++ b/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch @@ -8,7 +8,7 @@ Subject: [PATCH] fix mtmd-audio.cpp build on windows 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp -index 4d053895c..84bdc2777 100644 +index f68829a61..2024d3d37 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -1,6 +1,6 @@ diff --git a/llama/patches/0020-ggml-No-alloc-mode.patch b/llama/patches/0020-ggml-No-alloc-mode.patch index 95962f82..19f5f7e7 100644 --- a/llama/patches/0020-ggml-No-alloc-mode.patch +++ b/llama/patches/0020-ggml-No-alloc-mode.patch @@ -10,13 +10,13 @@ must be recreated with no-alloc set to false before loading data. 
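Editor's note — a minimal sketch of the measure-then-allocate flow this no-alloc mode is meant to support. The constructor flag the patch adds is not fully visible above, so scheduler creation is left out here; only the existing ggml APIs `ggml_backend_sched_reserve()` and `ggml_backend_sched_get_buffer_size()` are used, and the helper is an assumption for illustration.

```cpp
#include <stdio.h>
#include "ggml-backend.h"

// Sketch only: reserve (measure) compute buffers for a worst-case graph and report
// the per-backend sizes. Per the patch description, a scheduler created in no-alloc
// mode would run this same reserve pass without backing allocations, and must be
// recreated with allocation enabled before weights are loaded.
static void report_compute_buffer_sizes(ggml_backend_sched_t sched,
                                        ggml_backend_t * backends, int n_backends,
                                        struct ggml_cgraph * measure_graph) {
    if (!ggml_backend_sched_reserve(sched, measure_graph)) {
        fprintf(stderr, "reserve failed for the measurement graph\n");
    }
    for (int i = 0; i < n_backends; i++) {
        size_t sz = ggml_backend_sched_get_buffer_size(sched, backends[i]);
        fprintf(stderr, "backend %d: compute buffer %zu bytes\n", i, sz);
    }
}
```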
--- ggml/include/ggml-backend.h | 1 + ggml/src/ggml-backend-impl.h | 16 +++ - ggml/src/ggml-backend.cpp | 72 +++++++++- + ggml/src/ggml-backend.cpp | 75 ++++++++++- ggml/src/ggml-cuda/common.cuh | 62 ++++++++- ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------ - 5 files changed, 331 insertions(+), 44 deletions(-) + 5 files changed, 333 insertions(+), 45 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 2763f2bd6..b3b5b356a 100644 +index 93c95602d..dbbb61d9c 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -305,6 +305,7 @@ extern "C" { @@ -75,13 +75,19 @@ index 0f5b03cef..7bdf9d81f 100644 struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 312ca873c..4092dfe8a 100644 +index 498186a7c..7746e8b92 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -41,6 +41,19 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t +@@ -36,11 +36,25 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { + } + + ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +- GGML_ASSERT(buft); + if (size == 0) { + // return a dummy buffer for zero-sized allocations return ggml_backend_buffer_init(buft, {}, NULL, 0); } - ++ + if (buft->no_alloc) { + ggml_backend_buffer_t buf; + @@ -95,10 +101,11 @@ index 312ca873c..4092dfe8a 100644 + return buf; + } + - GGML_ASSERT(buft); ++ GGML_ASSERT(buft); return buft->iface.alloc_buffer(buft, size); } -@@ -95,7 +108,8 @@ ggml_backend_buffer_t ggml_backend_buffer_init( + +@@ -94,7 +108,8 @@ ggml_backend_buffer_t ggml_backend_buffer_init( /* .buft = */ buft, /* .context = */ context, /* .size = */ size, @@ -108,7 +115,7 @@ index 312ca873c..4092dfe8a 100644 }; return buffer; -@@ -127,6 +141,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { +@@ -126,6 +141,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return NULL; } @@ -118,10 +125,10 @@ index 312ca873c..4092dfe8a 100644 + return (void *)ggml_backend_buffer_get_alignment(buffer); + } + - void * base = buffer->iface.get_base(buffer); - - GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); -@@ -731,6 +751,12 @@ struct ggml_backend_sched { + // FIXME JG: a multi_buffer has a non-zero size, according to the above comment get_base is not optional, + // I don't know whether the above comment is correct + if (!buffer->iface.get_base) { +@@ -736,6 +757,12 @@ struct ggml_backend_sched { int debug_realloc; int debug_graph_size; int debug_prev_graph_size; @@ -134,7 +141,7 @@ index 312ca873c..4092dfe8a 100644 }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) -@@ -1630,6 +1656,17 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1635,6 +1662,17 @@ ggml_backend_sched_t ggml_backend_sched_new( size_t graph_size, bool parallel, bool op_offload) { @@ -152,7 +159,7 @@ index 312ca873c..4092dfe8a 100644 GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); -@@ -1682,11 +1719,14 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1687,11 +1725,14 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->events[b][c] = ggml_backend_event_new(backends[b]->device); } } @@ -167,7 +174,7 @@ index 312ca873c..4092dfe8a 100644 ggml_backend_sched_reset(sched); -@@ -1701,6 +1741,10 
@@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { +@@ -1706,6 +1747,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { for (int c = 0; c < sched->n_copies; c++) { ggml_backend_event_free(sched->events[b][c]); } @@ -178,7 +185,7 @@ index 312ca873c..4092dfe8a 100644 } ggml_gallocr_free(sched->galloc); ggml_free(sched->ctx); -@@ -1746,6 +1790,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * +@@ -1765,6 +1810,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * return false; } @@ -203,7 +210,7 @@ index 312ca873c..4092dfe8a 100644 ggml_backend_sched_reset(sched); return true; -@@ -1851,7 +1913,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, +@@ -1870,7 +1933,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); @@ -219,7 +226,7 @@ index 312ca873c..4092dfe8a 100644 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh -index c4529f5d9..8b0fb5d42 100644 +index 9fcb2f9fd..e800ee8f6 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -37,6 +37,41 @@ @@ -264,7 +271,7 @@ index c4529f5d9..8b0fb5d42 100644 #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) -@@ -938,6 +973,9 @@ struct ggml_cuda_pool { +@@ -941,6 +976,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; @@ -274,7 +281,7 @@ index c4529f5d9..8b0fb5d42 100644 }; template -@@ -1229,11 +1267,15 @@ struct ggml_backend_cuda_context { +@@ -1232,11 +1270,15 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; @@ -292,7 +299,7 @@ index c4529f5d9..8b0fb5d42 100644 } return *pools[device][curr_stream_no]; } -@@ -1241,6 +1283,22 @@ struct ggml_backend_cuda_context { +@@ -1244,6 +1286,22 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } @@ -316,7 +323,7 @@ index c4529f5d9..8b0fb5d42 100644 struct ggml_cuda_mm_fusion_args_host { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 17062697b..ede1d089a 100644 +index 25548629d..eeaae3fe4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -365,6 +365,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { diff --git a/llama/patches/0021-decode-disable-output_all.patch b/llama/patches/0021-decode-disable-output_all.patch index 7de5e378..20001bd9 100644 --- a/llama/patches/0021-decode-disable-output_all.patch +++ b/llama/patches/0021-decode-disable-output_all.patch @@ -8,10 +8,10 @@ Subject: [PATCH] decode: disable output_all 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp -index 417140071..87f407f99 100644 +index 8786d4ee3..9e6998272 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp -@@ -999,8 +999,7 @@ int llama_context::decode(const llama_batch & batch_inp) { +@@ -1051,8 +1051,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd_inp(); diff --git 
a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch index 1bcc0e31..3197f94e 100644 --- a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch @@ -16,7 +16,7 @@ unused then it can be reset to free these data structures. 6 files changed, 32 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index b3b5b356a..69223c488 100644 +index dbbb61d9c..92ca32a4b 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -178,6 +178,7 @@ extern "C" { @@ -43,10 +43,10 @@ index 7bdf9d81f..21b35ac5c 100644 struct ggml_backend_device { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 4092dfe8a..a1a19fe51 100644 +index 7746e8b92..189e97170 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp -@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par +@@ -532,6 +532,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par return device->iface.init_backend(device, params); } @@ -62,7 +62,7 @@ index 4092dfe8a..a1a19fe51 100644 GGML_ASSERT(device); return device->iface.get_buffer_type(device); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index ede1d089a..ec63cadab 100644 +index eeaae3fe4..6852d2e20 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -113,6 +113,11 @@ int ggml_cuda_get_device() { @@ -89,7 +89,7 @@ index ede1d089a..ec63cadab 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY -@@ -4907,6 +4915,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g +@@ -4908,6 +4916,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } @@ -101,7 +101,7 @@ index ede1d089a..ec63cadab 100644 static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, -@@ -4923,6 +4936,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { +@@ -4924,6 +4937,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, @@ -110,10 +110,10 @@ index ede1d089a..ec63cadab 100644 // backend reg diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index b7d6edf7f..b987d7aeb 100644 +index 951a88d56..4e162258d 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -45,6 +45,7 @@ +@@ -49,6 +49,7 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t @@ -122,10 +122,10 @@ index b7d6edf7f..b987d7aeb 100644 #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled diff --git a/src/llama.cpp b/src/llama.cpp -index ab2e9868a..74c49e651 100644 +index f69964b6d..759152b76 100644 --- a/src/llama.cpp +++ b/src/llama.cpp -@@ -270,10 +270,12 @@ static struct llama_model * llama_model_load_from_file_impl( +@@ 
-921,10 +921,12 @@ static struct llama_model * llama_model_load_from_file_impl( for (auto * dev : model->devices) { ggml_backend_dev_props props; ggml_backend_dev_get_props(dev, &props); diff --git a/llama/patches/0024-GPU-discovery-enhancements.patch b/llama/patches/0024-GPU-discovery-enhancements.patch index 86f57122..11106f4e 100644 --- a/llama/patches/0024-GPU-discovery-enhancements.patch +++ b/llama/patches/0024-GPU-discovery-enhancements.patch @@ -28,7 +28,7 @@ fix vulkan PCI ID and ID handling create mode 100644 ggml/src/mem_nvml.cpp diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 69223c488..6510e0cba 100644 +index 92ca32a4b..6ad583f09 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -169,6 +169,12 @@ extern "C" { @@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644 set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index ec63cadab..cd71902df 100644 +index 6852d2e20..48cdb1dcf 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() { @@ -159,7 +159,7 @@ index ec63cadab..cd71902df 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY bool events = false; -@@ -5046,6 +5102,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -5047,6 +5103,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; @@ -167,7 +167,7 @@ index ec63cadab..cd71902df 100644 for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; -@@ -5061,6 +5118,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -5062,6 +5119,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; @@ -183,7 +183,7 @@ index ec63cadab..cd71902df 100644 /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index b987d7aeb..5ad5623ae 100644 +index 4e162258d..d89e35a8e 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -5,6 +5,8 @@ @@ -195,7 +195,7 @@ index b987d7aeb..5ad5623ae 100644 #if defined(GGML_HIP_ROCWMMA_FATTN) #include -@@ -47,6 +49,7 @@ +@@ -51,6 +53,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize @@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644 /* .async = */ true, /* .host_buffer = */ false, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index b2c0d0cee..d9f4d34f5 100644 +index 5349bce24..d43d46d1d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -236,6 +236,7 @@ class vk_memory_logger; @@ -254,7 +254,7 @@ index b2c0d0cee..d9f4d34f5 100644 static constexpr uint32_t mul_mat_vec_max_cols = 8; static constexpr uint32_t p021_max_gqa_ratio = 8; -@@ -12256,6 +12257,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ +@@ -12350,6 +12351,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ snprintf(description, description_size, "%s", props.deviceName.data()); } @@ 
-284,7 +284,7 @@ index b2c0d0cee..d9f4d34f5 100644 // backend interface #define UNUSED GGML_UNUSED -@@ -13535,15 +13559,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size +@@ -13628,15 +13652,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } @@ -361,7 +361,7 @@ index b2c0d0cee..d9f4d34f5 100644 if (membudget_supported) { memprops.pNext = &budgetprops; -@@ -13595,8 +13676,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { +@@ -13688,8 +13769,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { } } @@ -376,7 +376,7 @@ index b2c0d0cee..d9f4d34f5 100644 } vk::PhysicalDeviceProperties2 props = {}; -@@ -13613,19 +13699,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { +@@ -13706,19 +13792,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); @@ -410,7 +410,7 @@ index b2c0d0cee..d9f4d34f5 100644 static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -@@ -13637,9 +13728,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de +@@ -13730,9 +13821,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de return ctx->description.c_str(); } @@ -426,7 +426,7 @@ index b2c0d0cee..d9f4d34f5 100644 } static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { -@@ -13663,8 +13759,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +@@ -13756,8 +13852,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); @@ -437,7 +437,7 @@ index b2c0d0cee..d9f4d34f5 100644 ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, -@@ -13672,6 +13769,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +@@ -13765,6 +13862,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml /* .buffer_from_host_ptr = */ false, /* .events = */ false, }; @@ -451,7 +451,7 @@ index b2c0d0cee..d9f4d34f5 100644 } static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { -@@ -14236,6 +14340,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14331,6 +14435,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { @@ -460,7 +460,7 @@ index b2c0d0cee..d9f4d34f5 100644 for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; -@@ -14244,12 +14350,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14339,12 +14445,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == 
vk::PhysicalDeviceType::eIntegratedGpu; diff --git a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch b/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch index f8106f0f..d45e4ec7 100644 --- a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch +++ b/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch @@ -38,7 +38,7 @@ index 1c07e767a..0da3e065b 100644 #ifdef __cplusplus } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index d9f4d34f5..8a83427fb 100644 +index d43d46d1d..df79f9f79 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); @@ -49,7 +49,7 @@ index d9f4d34f5..8a83427fb 100644 typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { VkStructureType sType; -@@ -13576,6 +13577,7 @@ struct ggml_backend_vk_device_context { +@@ -13669,6 +13670,7 @@ struct ggml_backend_vk_device_context { std::string pci_id; std::string id; std::string uuid; @@ -57,7 +57,7 @@ index d9f4d34f5..8a83427fb 100644 int major; int minor; int driver_major; -@@ -13594,6 +13596,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size +@@ -13687,6 +13689,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size vk::PhysicalDeviceProperties2 props2; vkdev.getProperties2(&props2); @@ -78,7 +78,7 @@ index d9f4d34f5..8a83427fb 100644 if (!is_integrated_gpu) { -@@ -13625,7 +13641,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size +@@ -13718,7 +13734,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size } // else fallback to memory budget if supported @@ -86,7 +86,7 @@ index d9f4d34f5..8a83427fb 100644 if (membudget_supported) { memprops.pNext = &budgetprops; } -@@ -14357,7 +14372,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14452,7 +14467,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, /* .reg = */ reg, /* .context = */ ctx, }); @@ -94,7 +94,7 @@ index d9f4d34f5..8a83427fb 100644 // Gather additional information about the device int dev_idx = vk_instance.device_indices[i]; vk::PhysicalDeviceProperties props1; -@@ -14380,6 +14394,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -14475,6 +14489,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, } } ctx->uuid = oss.str(); diff --git a/llama/patches/0029-ggml-cuda-skip-large-batches.patch b/llama/patches/0029-ggml-cuda-skip-large-batches.patch index 86f1840c..a1005c58 100644 --- a/llama/patches/0029-ggml-cuda-skip-large-batches.patch +++ b/llama/patches/0029-ggml-cuda-skip-large-batches.patch @@ -10,10 +10,10 @@ fallback to cpu 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index cd71902df..d69d62193 100644 +index 48cdb1dcf..3102d7ea7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -4632,6 +4632,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g +@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } diff --git a/llama/patches/0030-win-exit-instead-of-abort.patch b/llama/patches/0030-win-exit-instead-of-abort.patch index 7dc156e4..4e4edcbd 
100644 --- a/llama/patches/0030-win-exit-instead-of-abort.patch +++ b/llama/patches/0030-win-exit-instead-of-abort.patch @@ -8,7 +8,7 @@ Subject: [PATCH] win: exit instead of abort 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c -index 530ff7b95..fc0196eb7 100644 +index eb3ae72ea..c9242a15a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -250,8 +250,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { diff --git a/llama/patches/0031-fix-bakllava-regression.patch b/llama/patches/0031-fix-bakllava-regression.patch index fa306191..14ef26b5 100644 --- a/llama/patches/0031-fix-bakllava-regression.patch +++ b/llama/patches/0031-fix-bakllava-regression.patch @@ -9,10 +9,10 @@ Rever to prior logic of assuming an empty projector type is mlp 1 file changed, 4 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 6be1470ad..2a325c726 100644 +index 84a3796b5..d3a37842d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp -@@ -2649,6 +2649,10 @@ struct clip_model_loader { +@@ -960,6 +960,10 @@ struct clip_model_loader { if (proj_type.empty()) { if (modality == CLIP_MODALITY_VISION) { get_string(KEY_VISION_PROJ_TYPE, proj_type, false); diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index e04fb5a3..9ae5fd57 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -72,7 +72,7 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) { try { const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); std::vector splits = {}; - llama_model_loader ml(std::string(fname), splits, false, false, nullptr, nullptr); + llama_model_loader ml(std::string(fname), splits, false, false, false, nullptr, nullptr); vocab->load(ml, kv); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); diff --git a/llm/server.go b/llm/server.go index 49af4e1b..a89027b0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -143,7 +143,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st var llamaModel *llama.Model var textProcessor model.TextProcessor var err error - if envconfig.NewEngine(true) || f.KV().OllamaEngineRequired() { + if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { if len(projectors) == 0 { textProcessor, err = model.NewTextProcessor(modelPath) } else { diff --git a/ml/backend.go b/ml/backend.go index 1e781fa7..f287db6a 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -54,10 +54,6 @@ type CacheConfig struct { // MaskDType specifies the data type for generating the mask. If unset it will // default to DTypeF32. MaskDType DType - - // MaskBatchPadding specifies the multiple for the batch size dimension in the mask. - // Any position that does not correspond to an actual token will be filled with -Inf. 
- MaskBatchPadding int } // BackendParams controls how the backend loads and executes models diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index a50d8ec9..ebcc1d86 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -685,7 +685,7 @@ func (b *Backend) NewContextSize(n int) ml.Context { func (b *Backend) CacheConfig() ml.CacheConfig { if b.flashAttention == ml.FlashAttentionEnabled { - return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} + return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16} } else { return ml.CacheConfig{CachePadding: 256, PermutedV: true} } @@ -1534,7 +1534,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase unsafe.SliceData(mropeSections), C.int(opts.Type), cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10), - C.float(ropeBase), C.float(ropeScale), + C.float(ropeBase), + C.float(ropeScale), C.float(opts.YaRN.ExtrapolationFactor), cmp.Or(C.float(opts.YaRN.AttentionFactor), 1), cmp.Or(C.float(opts.YaRN.BetaFast), 32), @@ -1546,9 +1547,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase dequant, positions.(*Tensor).t, opts.Factors.(*Tensor).t, - C.int(ropeDim), C.int(opts.Type), + C.int(ropeDim), + C.int(opts.Type), cmp.Or(C.int(opts.YaRN.OriginalContextLength), 128<<10), - C.float(ropeBase), C.float(ropeScale), + C.float(ropeBase), + C.float(ropeScale), C.float(opts.YaRN.ExtrapolationFactor), cmp.Or(C.float(opts.YaRN.AttentionFactor), 1), cmp.Or(C.float(opts.YaRN.BetaFast), 32), @@ -1657,11 +1660,6 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin } if mask != nil { - padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1) - if padSize > 0 { - mask = mask.Pad(ctx, 0, padSize, 0, 0) - } - if mask.DType() != cacheConfig.MaskDType { mask = mask.Cast(ctx, cacheConfig.MaskDType) } diff --git a/ml/backend/ggml/ggml/include/ggml-alloc.h b/ml/backend/ggml/ggml/include/ggml-alloc.h index 7ab3f019..7fa8403b 100644 --- a/ml/backend/ggml/ggml/include/ggml-alloc.h +++ b/ml/backend/ggml/ggml/include/ggml-alloc.h @@ -53,7 +53,14 @@ GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc); // call with a worst-case graph to avoid buffer reallocations // not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed // returns false if the buffer allocation failed +// ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph); +GGML_API void ggml_gallocr_reserve_n_size( + ggml_gallocr_t galloc, + struct ggml_cgraph * graph, + const int * node_buffer_ids, + const int * leaf_buffer_ids, + size_t * sizes); GGML_API bool ggml_gallocr_reserve_n( ggml_gallocr_t galloc, struct ggml_cgraph * graph, @@ -69,6 +76,8 @@ GGML_API size_t ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, in // Utils // Create a buffer and allocate all the tensors in a ggml_context +// ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft +GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, 
ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index 6510e0cb..6ad583f0 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -319,6 +319,7 @@ extern "C" { GGML_API void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size); // Initialize backend buffers from a measure graph + GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes); GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched); diff --git a/ml/backend/ggml/ggml/include/ggml-cpu.h b/ml/backend/ggml/ggml/include/ggml-cpu.h index 9edd4851..4f3b99c8 100644 --- a/ml/backend/ggml/ggml/include/ggml-cpu.h +++ b/ml/backend/ggml/ggml/include/ggml-cpu.h @@ -99,6 +99,7 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_sme (void); // other GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); + GGML_BACKEND_API int ggml_cpu_get_rvv_vlen (void); // risc-v vector length in bytes GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vxe (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); diff --git a/ml/backend/ggml/ggml/include/ggml-zendnn.h b/ml/backend/ggml/ggml/include/ggml-zendnn.h new file mode 100644 index 00000000..a30a3a98 --- /dev/null +++ b/ml/backend/ggml/ggml/include/ggml-zendnn.h @@ -0,0 +1,22 @@ +#pragma once + +#include "ggml-backend.h" +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void); + +GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend); + +// number of threads used for zendnn operations +GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h index 6bc762c0..20c912d0 100644 --- a/ml/backend/ggml/ggml/include/ggml.h +++ b/ml/backend/ggml/ggml/include/ggml.h @@ -2305,13 +2305,11 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 1 - - // q: [n_embd_k, n_batch, n_head, ne3 ] - // k: [n_embd_k, n_kv, n_head_kv, ne3 ] - // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! - // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! + // q: [n_embd_k, n_batch, n_head, ne3 ] + // k: [n_embd_k, n_kv, n_head_kv, ne3 ] + // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! + // mask: [n_kv, n_batch, ne32, ne33] + // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! // // broadcast: // n_head % n_head_kv == 0 @@ -2617,7 +2615,8 @@ extern "C" { // Set callback for all future logging events. // If this is not called, or NULL is supplied, everything is output on stderr. 
- GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); + GGML_API void ggml_log_get(ggml_log_callback * log_callback, void ** user_data); + GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index dbfd8b5b..73b39bfe 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -312,16 +312,9 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al } // this is a very naive implementation, but for our case the number of free blocks should be very small -static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, size_t size, const struct ggml_tensor * tensor) { +static void ggml_dyn_tallocr_free_bytes(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, size_t size) { size = aligned_offset(NULL, size, alloc->alignment); - AT_PRINTF("%s: freeing %s at {chunk=%d, offset=%zu} (%zu bytes) - n_free_blocks = %d\n", - __func__, tensor->name, addr.chunk, addr.offset, size, alloc->chunks[addr.chunk]->n_free_blocks); - -#ifdef GGML_ALLOCATOR_DEBUG - remove_allocated_tensor(alloc, addr, tensor); -#endif - struct tallocr_chunk * chunk = alloc->chunks[addr.chunk]; // see if we can merge with an existing block @@ -357,8 +350,6 @@ static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, struct } // otherwise, add a new block ggml_dyn_tallocr_insert_block(chunk, addr.offset, size); - - GGML_UNUSED(tensor); } static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) { @@ -608,7 +599,9 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) { } static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) { - return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; + return t->data != NULL // tensor data already set externally + || t->buffer // tensor on external buffer (but not yet allocated) + || ggml_gallocr_is_own(galloc, t); // tensor will be allocated by galloc } // free the extra space at the end if the new tensor is smaller @@ -621,13 +614,17 @@ static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_ten GGML_ASSERT(parent_size >= node_size); + // note: we want after the freeing the chunks to continue to be aligned + struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id]; + parent_size = aligned_offset(NULL, parent_size, p_alloc->alignment); + node_size = aligned_offset(NULL, node_size, p_alloc->alignment); + if (parent_size > node_size) { - struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id]; struct buffer_address p_addr = p_hn->addr; p_addr.offset += node_size; size_t extra_size = parent_size - node_size; AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name); - ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent); + ggml_dyn_tallocr_free_bytes(p_alloc, p_addr, extra_size); } } @@ -711,7 +708,14 @@ static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * n struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; size_t size = ggml_backend_buft_get_alloc_size(buft, node); - ggml_dyn_tallocr_free_tensor(alloc, hn->addr, size, node); + + AT_PRINTF("%s: freeing %s at {chunk=%d, 
offset=%zu} (%zu bytes) - n_free_blocks = %d\n", + __func__, node->name, hn->addr.chunk, hn->addr.offset, size, alloc->chunks[hn->addr.chunk]->n_free_blocks); +#ifdef GGML_ALLOCATOR_DEBUG + remove_allocated_tensor(alloc, hn->addr, node); +#endif + + ggml_dyn_tallocr_free_bytes(alloc, hn->addr, size); hn->allocated = false; } @@ -826,7 +830,8 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr } } -bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) { +static bool ggml_gallocr_reserve_n_impl( + ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids, bool no_alloc) { size_t min_hash_size = graph->n_nodes + graph->n_leafs; // add 25% margin to avoid hash collisions min_hash_size += min_hash_size / 4; @@ -933,19 +938,22 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; if (cur_size > 0) { GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", - __func__, ggml_backend_buft_name(galloc->bufts[i]), - cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); } } #endif ggml_vbuffer_free(galloc->buffers[i]); - galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); - if (galloc->buffers[i]) { - galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); + if (no_alloc) { + galloc->buffers[i] = NULL; } else { - GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); - galloc->buffer_sizes[i] = new_size; - success = false; + galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); + if (galloc->buffers[i]) { + galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); + } else { + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); + galloc->buffer_sizes[i] = new_size; + success = false; + } } } else { galloc->buffer_sizes[i] = ggml_vbuffer_size(galloc->buffers[i]); @@ -955,6 +963,21 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c return success; } +void ggml_gallocr_reserve_n_size( + ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids, size_t * sizes) { + GGML_ASSERT(ggml_gallocr_reserve_n_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids, /*no_alloc =*/ true)); + for (int i = 0; i < galloc->n_buffers; i++) { + sizes[i] = 0; + for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) { + sizes[i] += galloc->buf_tallocs[i]->chunks[c]->max_size; + } + } +} + +bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) { + return ggml_gallocr_reserve_n_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids, /*no_alloc =*/ false); +} + bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL); } @@ -1173,7 +1196,8 @@ static bool alloc_tensor_range(struct ggml_context * ctx, return true; } -ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context 
* ctx, ggml_backend_buffer_type_t buft) { +static ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft_impl( + struct ggml_context * ctx, ggml_backend_buffer_type_t buft, size_t * nbytes_total, bool no_alloc) { GGML_ASSERT(ggml_get_no_alloc(ctx) == true); size_t alignment = ggml_backend_buft_get_alignment(buft); @@ -1181,6 +1205,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte ggml_backend_buffer_t * buffers = NULL; size_t n_buffers = 0; + *nbytes_total = 0; size_t cur_buf_size = 0; struct ggml_tensor * first = ggml_get_first_tensor(ctx); @@ -1192,10 +1217,11 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) { // allocate tensors in the current buffer - if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) { + if (!no_alloc && !alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) { return NULL; } first = t; + *nbytes_total += cur_buf_size; cur_buf_size = this_size; } else { cur_buf_size += this_size; @@ -1204,15 +1230,21 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte // allocate remaining tensors if (cur_buf_size > 0) { - if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) { + *nbytes_total += cur_buf_size; + if (!no_alloc && !alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) { return NULL; } } + if (no_alloc) { + return NULL; + } + if (n_buffers == 0) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__); #endif + GGML_ASSERT(!buffers); return NULL; } @@ -1222,10 +1254,24 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte } else { buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers); } - free(buffers); + if (buffers) { + free(buffers); // can be NULL if context is empty or no_alloc + } return buffer; } +size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + size_t nbytes_total = 0; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc=*/ true); + GGML_ASSERT(!buf); + return nbytes_total; +} + +ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + size_t nbytes_total = 0; + return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false); +} + ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) { return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend)); } diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index a1a19fe5..189e9717 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -147,6 +147,12 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return (void *)ggml_backend_buffer_get_alignment(buffer); } + // FIXME JG: a multi_buffer has a non-zero size, according to the above comment get_base is not optional, + // I don't know whether the above comment is correct + if (!buffer->iface.get_base) { + return NULL; + } + void * base = buffer->iface.get_base(buffer); GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); @@ -1786,6 +1792,20 @@ void 
ggml_backend_sched_reset(ggml_backend_sched_t sched) { sched->is_alloc = false; } +void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes) { + GGML_ASSERT(sched); + GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); + GGML_ASSERT(sizes); + + ggml_backend_sched_reset(sched); + + ggml_backend_sched_synchronize(sched); + + ggml_backend_sched_split_graph(sched, measure_graph); + + ggml_gallocr_reserve_n_size(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids, sizes); +} + bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { GGML_ASSERT(sched); GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp index 683ed8d2..fb7f074a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -24,6 +24,7 @@ #define UNUSED GGML_UNUSED +#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD)) static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) { @@ -46,6 +47,7 @@ static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); memcpy(out_scales, scales_u32, 8); } +#endif void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c index 47089a62..8d485131 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c @@ -83,6 +83,11 @@ struct ggml_arm_arch_features_type { } ggml_arm_arch_features = { 0 }; #endif +#if defined(__riscv) +struct ggml_riscv_arch_features_type { + int rvv_vlen; +} ggml_riscv_arch_features = { 0 }; +#endif #if defined(_WIN32) @@ -189,6 +194,9 @@ typedef void * thread_ret_t; typedef pthread_t ggml_thread_t; +#define GGML_THREADPOOL_N_THREADS_MASK (0xffffU) +#define GGML_THREADPOOL_N_THREADS_BITS (16) + #if defined(__APPLE__) #include #include @@ -451,7 +459,7 @@ struct ggml_threadpool { struct ggml_cplan * cplan; // synchronization primitives - atomic_int n_graph; // incremented when there is work to be done (i.e each graph) + atomic_int n_graph; // updated when there is work to be done (i.e each graph) holds graph and active thread counts. atomic_int GGML_CACHE_ALIGN n_barrier; atomic_int GGML_CACHE_ALIGN n_barrier_passed; atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. 
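Reviewer note (not part of the patch): the threadpool hunks above and below fold the active-thread count into the low bits of the n_graph atomic, using the new GGML_THREADPOOL_N_THREADS_MASK/BITS constants. A minimal standalone C sketch of that packing scheme, assuming the same 16-bit split; the names here are illustrative only:

    #include <stdatomic.h>
    #include <stdio.h>

    #define GGML_THREADPOOL_N_THREADS_MASK (0xffffU)
    #define GGML_THREADPOOL_N_THREADS_BITS (16)

    int main(void) {
        atomic_int n_graph = 0; // stands in for threadpool->n_graph

        // kickoff side: bump the graph sequence number (high bits) and publish
        // the number of active threads for this graph (low bits) in one store
        int n_threads = 8;
        int g = atomic_load_explicit(&n_graph, memory_order_relaxed) >> GGML_THREADPOOL_N_THREADS_BITS;
        g = ((g + 1) << GGML_THREADPOOL_N_THREADS_BITS) | (n_threads & GGML_THREADPOOL_N_THREADS_MASK);
        atomic_store_explicit(&n_graph, g, memory_order_seq_cst);

        // worker side: a single relaxed load recovers both fields, so the barrier
        // and the "is this thread active for the new graph" check share one atomic
        int v = atomic_load_explicit(&n_graph, memory_order_relaxed);
        printf("graph seq = %d, active threads = %d\n",
               v >> GGML_THREADPOOL_N_THREADS_BITS, v & GGML_THREADPOOL_N_THREADS_MASK);
        return 0;
    }

This is why the separate n_threads_max/n_threads_cur fields collapse into a single n_threads field plus the packed atomic in the hunks that follow.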
@@ -459,12 +467,10 @@ struct ggml_threadpool { // these are atomic as an annotation for thread-sanitizer atomic_bool stop; // Used for stopping the threadpool altogether atomic_bool pause; // Used for pausing the threadpool or individual threads - atomic_int abort; // Used for aborting processing of a graph + atomic_int abort; // Used for aborting processing of a graph struct ggml_compute_state * workers; // per thread state - int n_threads_max; // number of threads in the pool - atomic_int n_threads_cur; // number of threads used in the current graph - + int n_threads; // Number of threads in the pool int32_t prio; // Scheduling priority uint32_t poll; // Polling level (0 - no polling) @@ -541,7 +547,7 @@ struct ggml_state { static struct ggml_state g_state = {0}; void ggml_barrier(struct ggml_threadpool * tp) { - int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); + int n_threads = atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK; if (n_threads == 1) { return; } @@ -558,7 +564,7 @@ void ggml_barrier(struct ggml_threadpool * tp) { // last thread atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed); - // exit barrier (fill seq-cst fence) + // exit barrier (full seq-cst fence) atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst); return; } @@ -704,6 +710,15 @@ static void ggml_init_arm_arch_features(void) {} #endif #endif // __ARM_ARCH +#if defined(__riscv) && defined(__riscv_v_intrinsic) +#include +static void ggml_init_riscv_arch_features(void) { + ggml_riscv_arch_features.rvv_vlen = __riscv_vlenb(); +} +#else +static void ggml_init_riscv_arch_features(void) {} +#endif + struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { GGML_ASSERT(!ggml_get_no_alloc(ctx)); @@ -2630,7 +2645,7 @@ static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask void ggml_threadpool_free(struct ggml_threadpool* threadpool) { if (!threadpool) return; - const int n_threads = threadpool->n_threads_max; + const int n_threads = threadpool->n_threads; #ifndef GGML_USE_OPENMP struct ggml_compute_state* workers = threadpool->workers; @@ -2706,7 +2721,7 @@ struct ggml_cplan ggml_graph_plan( //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); } if (n_threads <= 0) { - n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; + n_threads = threadpool ? 
threadpool->n_threads : GGML_DEFAULT_N_THREADS; } #if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) @@ -2914,12 +2929,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_params params = { /*.ith =*/ state->ith, - /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed), + /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, /*.wsize =*/ cplan->work_size, /*.wdata =*/ cplan->work_data, /*.threadpool=*/ tp, }; + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); + for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -2945,6 +2962,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); + ggml_barrier(state->threadpool); return 0; @@ -2952,27 +2971,23 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { #ifndef GGML_USE_OPENMP -// check if thread is active -static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) { - struct ggml_threadpool * threadpool = state->threadpool; - int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed); - return (state->ith < n_threads); -} - // check if thread is ready to proceed (exit from polling or sleeping) +// returns true if loops should exit, sets state->pending to indicate new work static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) { struct ggml_threadpool * threadpool = state->threadpool; if (state->pending || threadpool->stop || threadpool->pause) { return true; } // check for new graph/work - int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); - if (new_graph != state->last_graph) { - state->pending = ggml_graph_compute_thread_active(state); - state->last_graph = new_graph; + int n_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); + int n_threads = n_graph & GGML_THREADPOOL_N_THREADS_MASK; + if (n_graph != state->last_graph) { + state->pending = (state->ith < n_threads); + state->last_graph = n_graph; + return true; } - return state->pending; + return false; } // sync thread state after polling @@ -2989,11 +3004,6 @@ static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * st static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) { struct ggml_threadpool * threadpool = state->threadpool; - // Skip polling for unused threads - if (!ggml_graph_compute_thread_active(state)) { - return state->pending; - } - // This seems to make 0 ... 100 a decent range for polling level across modern processors. // Perhaps, we can adjust it dynamically based on load and things. 
const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; @@ -3055,7 +3065,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { ggml_graph_compute_check_for_work(state); if (state->pending) { state->pending = false; - ggml_graph_compute_thread(state); } } @@ -3070,14 +3079,15 @@ static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int ggml_mutex_lock(&threadpool->mutex); - GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads); + // Update the number of active threads and the graph count + int n_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed) >> GGML_THREADPOOL_N_THREADS_BITS; + n_graph = ((n_graph + 1) << GGML_THREADPOOL_N_THREADS_BITS) | (n_threads & GGML_THREADPOOL_N_THREADS_MASK); - // Update the number of active threads - atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + GGML_PRINT_DEBUG("compute-kickoff: n_threads %d n_graph %d\n", n_threads, n_graph); // Indicate the graph is ready to be processed // We need the full seq-cst fence here because of the polling threads (used in thread_sync) - atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst); + atomic_store_explicit(&threadpool->n_graph, n_graph, memory_order_seq_cst); if (threadpool->pause) { // Update main thread prio and affinity to match the threadpool settings @@ -3115,8 +3125,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl( threadpool->pause = tpp->paused; threadpool->abort = -1; threadpool->workers = NULL; - threadpool->n_threads_max = tpp->n_threads; - threadpool->n_threads_cur = tpp->n_threads; + threadpool->n_threads = tpp->n_threads; threadpool->poll = tpp->poll; threadpool->prio = tpp->prio; threadpool->ec = GGML_STATUS_SUCCESS; @@ -3211,7 +3220,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl { // update the number of threads from the actual number of threads that we got from OpenMP n_threads = omp_get_num_threads(); - atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + atomic_store_explicit(&threadpool->n_graph, n_threads, memory_order_relaxed); } // Apply thread CPU mask and priority @@ -3224,13 +3233,13 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl ggml_graph_compute_thread(&threadpool->workers[ith]); } } else { - atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); + atomic_store_explicit(&threadpool->n_graph, 1, memory_order_relaxed); ggml_graph_compute_thread(&threadpool->workers[0]); } #else - if (n_threads > threadpool->n_threads_max) { - GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max); - n_threads = threadpool->n_threads_max; + if (n_threads > threadpool->n_threads) { + GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads); + n_threads = threadpool->n_threads; } // Kick all threads to start the new graph @@ -3470,6 +3479,14 @@ int ggml_cpu_has_riscv_v(void) { #endif } +int ggml_cpu_get_rvv_vlen(void) { +#if defined(__riscv) && defined(__riscv_v_intrinsic) + return ggml_riscv_arch_features.rvv_vlen; +#else + return 0; +#endif +} + int ggml_cpu_has_f16c(void) { #if defined(__F16C__) return 1; @@ -3636,6 +3653,10 @@ void ggml_cpu_init(void) { ggml_init_arm_arch_features(); #endif +#if defined(__riscv) + ggml_init_riscv_arch_features(); +#endif + is_first_call = false; } diff --git 
a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp index 32f14c81..92ba577a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -585,6 +585,10 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_riscv_v()) { features.push_back({ "RISCV_V", "1" }); } + if (ggml_cpu_get_rvv_vlen() > 0) { + static std::string rvv_vlen = std::to_string(ggml_cpu_get_rvv_vlen()); + features.push_back({ "RVV_VLEN", rvv_vlen.c_str() }); + } if (ggml_cpu_has_vsx()) { features.push_back({ "VSX", "1" }); } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h new file mode 100644 index 00000000..a7078687 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h @@ -0,0 +1,333 @@ +#pragma once + +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; + +template +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, + int ith, int nth); + + void matmul(int64_t m, int64_t n); + void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { + vec_t A_pack[mc*kc*2]; + vec_t B_pack[nc*kc*2]; + int comparray[mc*kc]; + constexpr bool is_Ablock_q4 = std::is_same_v; + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + if constexpr(is_Ablock_q4) { + packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray); + } else { + packNormal_large(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray); + } + packNormal_large(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true); + KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray); + } + } + } + + private: + inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); + } + } + } + + inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); + *c_ptr += *((float*)&fin_res[idx+I]+J); + } + } + } + + template + inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); + } + } + + inline void process_q4_elements(vector signed char (&c)[2], int* ca) { + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], 
lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + } + + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + assert(false && "RN/RM values not supported"); + } + } + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray); + template + void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip); + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n); + void KERNEL_4x8(int64_t ii, int64_t jj); + void KERNEL_8x4(int64_t ii, int64_t jj); + void KERNEL_8x8(int64_t ii, int64_t jj); + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN); + template + void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n); + + void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){ + for (int I = 0; I<8; I++) { + float a_scale = unhalf((A+((ii+I)*lda)+blk)->d); + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d)); + *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d)); + } + } + } + + inline void process_q8_elements(const int8_t *qs, int *ca) { + vector signed char c1 = vec_xl(0, qs); + vector signed char c2 = vec_xl(16, qs); + vector signed int vsum1 = {0}; + vector signed int vsum2 = {0}; + vsum1 = vec_sum4s(c1, vsum1); + vsum2 = vec_sum4s(c2, vsum2); + vector signed int vsum = vec_add(vsum1, vsum2); + *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) { + int64_t i, j; + block_q8_0 *aoffset = NULL; + VA *vecOffset = NULL; + block_q8_0* aoffsets[8]; + __vector_pair arr[8]; + VB c[8][2] = {0}; + VB c1[8] = {0}; VB c2[8] = {0}; + aoffset = const_cast(a); + vecOffset = vec; + j = (rows >> 3); + int index = 0; + if (j > 0) { + do { + for (int it = 0; it < 8; it++) + aoffsets[it] = aoffset + it*lda; + aoffset += 8 * lda; + for (int blk = 0; blk < kc; blk++) { + for (int it = 0; 
it < 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + if (comparray){ + process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]); + } + } + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + vecOffset += 256; + } + j--; + index += 8*kc; + } while(j > 0); + } + + } + + void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) { + int64_t i, j; + TA *aoffset = NULL; + int8_t *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + aoffset = const_cast(a); + vecOffset = vec; + int index = 0; + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + for (int blk = 0; blk < kc; blk++) { + c1[1] = reinterpret_cast(vec_xl(0, (aoffset1+blk)->qs)); + c2[1] = reinterpret_cast(vec_xl(0, (aoffset2+blk)->qs)); + c3[1] = reinterpret_cast(vec_xl(0, (aoffset3+blk)->qs)); + c4[1] = reinterpret_cast(vec_xl(0, (aoffset4+blk)->qs)); + c5[1] = reinterpret_cast(vec_xl(0, (aoffset5+blk)->qs)); + c6[1] = reinterpret_cast(vec_xl(0, (aoffset6+blk)->qs)); + c7[1] = reinterpret_cast(vec_xl(0, (aoffset7+blk)->qs)); + c8[1] = reinterpret_cast(vec_xl(0, (aoffset8+blk)->qs)); + + process_q4_elements(c1, &comparray[index + 8*blk+0]); + process_q4_elements(c2, &comparray[index + 8*blk+1]); + process_q4_elements(c3, &comparray[index + 8*blk+2]); + process_q4_elements(c4, &comparray[index + 8*blk+3]); + process_q4_elements(c5, &comparray[index + 8*blk+4]); + process_q4_elements(c6, &comparray[index + 8*blk+5]); + process_q4_elements(c7, &comparray[index + 8*blk+6]); + process_q4_elements(c8, &comparray[index + 8*blk+7]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); + vecOffset += 256; + } + j--; + index += 8*kc; + } while (j > 0); + } + } + + void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) { + acc_t acc[8]; + for (int i = 0; i < mc ; i += 8) { + for (int j = 0; j < nc; j += 8) { + vector float fin_res[16] = {0}; + vector float vs[16] = {0}; + for (int64_t kk = 0; kk < kc; kk+=2) { + for (int x = 0; x < 8; x++) { + __builtin_mma_xxsetaccz(&acc[x]); + } + int A_block_idx = (i/8)*(16*kc) + kk*16; + int B_block_idx = (j/8)*(16*kc)+ kk*16; + vec_t *A_block = &vec_A[A_block_idx]; + vec_t *B_block = &vec_B[B_block_idx]; + for (int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]); + __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]); + 
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]); + __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]); + } + compute_scale(ii+i, jj+j, l+kk, vs); + int c_index = (i/8)*(8*kc)+ kk*8; + int* c_block = &comparray[c_index]; + compute(&acc[0], 0, 0, c_block, vs, fin_res); + compute(&acc[1], 4, 4, c_block, vs, fin_res); + compute(&acc[2], 0, 8, c_block, vs, fin_res); + compute(&acc[3], 4, 12, c_block, vs, fin_res); + + A_block_idx = (i/8)*(16*kc) + (kk+1)*16; + B_block_idx = (j/8)*(16*kc)+ (kk+1)*16; + A_block = &vec_A[A_block_idx]; + B_block = &vec_B[B_block_idx]; + for (int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]); + __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]); + __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]); + __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]); + } + compute_scale(ii+i, jj+j, l+kk+1, vs); + c_index = (i/8)*(8*kc)+ (kk+1)*8; + c_block = &comparray[c_index]; + compute(&acc[4], 0, 0, c_block, vs, fin_res); + compute(&acc[5], 4, 4, c_block, vs, fin_res); + compute(&acc[6], 0, 8, c_block, vs, fin_res); + compute(&acc[7], 4, 12, c_block, vs, fin_res); + + } + if (l == 0) { + save_res(ii+i, jj+j, 0, fin_res); + save_res(ii+i+4, jj+j, 4, fin_res); + save_res(ii+i, jj+j+4, 8, fin_res); + save_res(ii+i+4, jj+j+4, 12, fin_res); + } else { + add_save_res(ii+i, jj+j, 0, fin_res); + add_save_res(ii+i+4, jj+j, 4, fin_res); + add_save_res(ii+i, jj+j+4, 8, fin_res); + add_save_res(ii+i+4, jj+j+4, 12, fin_res); + } + } + } + } + + const TA *const A; + const block_q8_0 *const B; + float *C; + const int64_t k; + int64_t kc; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp index 9f0d449b..b70ea7d7 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp @@ -2169,7 +2169,8 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; if (cur->type == GGML_TYPE_Q4_0) { - if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) + || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) { if (cur->ne[1] % 8 == 0) { return &q4_0_8x8_q8_0; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index 8b0fb5d4..e800ee8f 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -102,19 +102,22 @@ static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t coun #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA3_5 (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops. 
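+// RDNA 3.5 here refers to the gfx1150/gfx1151 iGPUs found in Ryzen AI 300-series laptops. It gets its
+// own range so RDNA 3.0 can still be special-cased (see GGML_CUDA_CC_IS_RDNA3_0 / GGML_CUDA_CC_IS_RDNA3_5
+// below), while GGML_CUDA_CC_IS_RDNA3 continues to match both generations.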
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 -#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) -#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) -#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) -#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) -#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) -#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) -#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) -#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5) +#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA3(cc) (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) +#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh index 2750117a..8dc82a9d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh @@ -642,8 +642,8 @@ static __global__ void flash_attn_stream_k_fixup( const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -679,7 +679,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. 
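            // Note: the int64_t cast above keeps the bidx*(iter_k*iter_j*...) product from overflowing
            // 32-bit int for large tensors; an equal kbc/kbc_stop means block bidx covered no data,
            // so step back to the previous block.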
bidx--; kbc_stop = kbc; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index d51537f7..7bd1044c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1380,8 +1380,8 @@ static __global__ void flash_attn_ext_f16( const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; // kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1401,7 +1401,7 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half * mask_h = ncols2 == 1 && !mask ? nullptr : - (const half *) (mask + nb33*(sequence % ne33)); + (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index d69d6219..3102d7ea 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4588,6 +4588,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_EXPM1: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_XIELU: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: @@ -4907,9 +4908,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CUMSUM: case GGML_OP_TRI: case GGML_OP_DIAG: - return true; case GGML_OP_SOLVE_TRI: - return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; + return true; + default: return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh index 0b13293d..dcfa40f4 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh @@ -189,6 +189,9 @@ namespace ggml_cuda_mma { return 8 * (threadIdx.x / 16) + l; #elif defined(RDNA3) return 2 * l + (threadIdx.x / 16); +#else + NO_DEVICE_CODE; + return -1; #endif // defined(RDNA4) } else { NO_DEVICE_CODE; @@ -290,8 +293,12 @@ namespace ggml_cuda_mma { } } #elif defined(AMD_WMMA_AVAILABLE) - +#if defined(RDNA3) + // RDNA3 has duplicated data as input. 
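+        // gfx11 wave32 WMMA expects the A/B operands replicated across the low and high halves of the
+        // wave, so each thread holds twice as many elements as on RDNA4; the load path below therefore
+        // issues two copies per thread.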
+ static constexpr int ne = I * J / 32 * 2; +#else static constexpr int ne = I * J / 32; +#endif // defined(RDNA3) half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -310,7 +317,14 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { +#if defined(RDNA4) return 4 * (threadIdx.x / 16) + l; +#elif defined(RDNA3) + return l; +#else + NO_DEVICE_CODE; + return -1; +#endif // defined(RDNA4) } else { NO_DEVICE_CODE; return -1; @@ -366,11 +380,16 @@ namespace ggml_cuda_mma { static constexpr int I = I_; static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; - static constexpr int ne = I * J / WARP_SIZE; - - nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; #if defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA3) + // RDNA3 has duplicated data as input. + static constexpr int ne = I * J / 32 * 2; +#else + static constexpr int ne = I * J / 32; +#endif // defined(RDNA3) + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + static constexpr __device__ bool supported() { if (I == 16 && J == 8) return true; return false; @@ -387,13 +406,23 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { +#if defined(RDNA4) return 4 * (threadIdx.x / 16) + l; +#elif defined(RDNA3) + return l; +#else + NO_DEVICE_CODE; + return -1; +#endif // defined(RDNA4) } else { NO_DEVICE_CODE; return -1; } } #else + static constexpr int ne = I * J / WARP_SIZE; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + static constexpr __device__ bool supported() { if (I == 8 && J == 8) return true; if (I == 16 && J == 4) return true; @@ -546,8 +575,14 @@ namespace ggml_cuda_mma { } #elif defined(AMD_WMMA_AVAILABLE) if constexpr (std::is_same_v || std::is_same_v) { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - +#if defined(RDNA4) + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); +#elif defined(RDNA3) + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); + ggml_cuda_memcpy_1(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2)); +#else + NO_DEVICE_CODE; +#endif // defined(RDNA4) } else if constexpr (std::is_same_v) { if constexpr (I == 16 && J == 4) { int64_t * xi = (int64_t *) t.x; @@ -888,6 +923,16 @@ namespace ggml_cuda_mma { const halfx8_t& a_frag = reinterpret_cast(A.x[0]); const halfx8_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#elif defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx16_t& a_frag = reinterpret_cast(A.x[0]); + const halfx16_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; #endif // RDNA4 #else GGML_UNUSED_VARS(D, A, B); @@ -905,6 +950,16 @@ namespace ggml_cuda_mma { const bf16x8_t& a_frag = reinterpret_cast(A.x[0]); const bf16x8_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#elif defined(RDNA3) + using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast(D.x[0]); + const bf16x16_t& a_frag = reinterpret_cast(A.x[0]); + const bf16x16_t& b_frag = 
reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; #endif // RDNA4 #else GGML_UNUSED_VARS(D, A, B); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu index 7cf33f0d..6643f243 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } else { - if (src1_ncols > 16) { + if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) { + return false; + } else if (src1_ncols > 16) { return false; } } @@ -160,9 +162,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc)); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc) || (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc)); + return ampere_mma_available(cc) || amd_wmma_available(cc); default: return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu index 6238ce7e..32948e4d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu @@ -765,7 +765,10 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0 return ne11 <= 8; } else if (GGML_CUDA_CC_IS_AMD(cc)) { if (fp16_mma_hardware_available(cc)) { - if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return ne11 <= 3; + } + if (GGML_CUDA_CC_IS_RDNA4(cc)) { return ne11 <= 5; } return ne11 <= 2; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu b/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu index e161d4dc..177ffc26 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/solve_tri.cu @@ -3,6 +3,80 @@ #include "solve_tri.cuh" #define MAX_N_FAST 64 +#define MAX_K_FAST 32 + +static __global__ void get_batch_pointers(const float * A, + float * X, + const float ** A_ptrs, + float ** X_ptrs, + int64_t ne02, + int64_t total_batches, + size_t s02, + size_t s03, + size_t s2, + size_t s3) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_batches) { + return; + } + + const int64_t i3 = idx / ne02; + const int64_t i2 = idx % ne02; + + A_ptrs[idx] = A + i3 * s03 + i2 * s02; + X_ptrs[idx] = X + i3 * s3 + i2 * s2; +} + +static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, + const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t s02, + size_t s03, + size_t s12, + size_t s13, + size_t s2, + size_t s3, + cudaStream_t stream) { + const float alpha = 1.0f; + const int64_t total_batches = ne02 * ne03; + if (total_batches == 0) { + return; + } + + // Bulk copy B -> X (contiguous tensors) + if (X != B) { + const int64_t total_elements_BX = n * k * total_batches; + CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + } + + const int id = ggml_cuda_get_device(); + + ggml_cuda_pool_alloc A_ptrs_alloc(ctx.pool(id), total_batches); + ggml_cuda_pool_alloc X_ptrs_alloc(ctx.pool(id), total_batches); + + const float ** A_ptrs_dev = 
A_ptrs_alloc.get(); + float ** X_ptrs_dev = X_ptrs_alloc.get(); + + get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02, + total_batches, s02, s03, s2, s3); + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + + // Yes, this is necessary, without this we get RMSE errors + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH)); + CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, + CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches)); + + // revert to standard mode from common.cuh + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH)); + + GGML_UNUSED_VARS(s12, s13); +} // ====================== // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction @@ -63,7 +137,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; - const int half = WARP_SIZE; + const int half = WARP_SIZE; const int nrows_low = (n < half) ? n : half; #pragma unroll @@ -81,8 +155,8 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, #pragma unroll for (int row = half; row < n; ++row) { - float sum = sA[row * n + lane] * x_low; - const int j = half + lane; + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; if (j < row) { sum += sA[row * n + j] * x_high; } @@ -97,7 +171,7 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, for (int rr = 0; rr < 2; ++rr) { const int row = rr * WARP_SIZE + lane; if (row < n) { - const float val = (row < half) ? x_low : x_high; + const float val = (row < half) ? 
x_low : x_high; X_batch[row * k + col_idx] = val; } } @@ -176,20 +250,26 @@ static void solve_tri_f32_cuda(const float * A, } void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix) - const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) + const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular) + const ggml_tensor * src1 = dst->src[1]; // B (n×k) ggml_is_contiguous(src0); ggml_is_contiguous(src1); - const int64_t n = src0->ne[0]; - const int64_t k = src1->ne[0]; + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; - GGML_ASSERT(n <= 64); - GGML_ASSERT(k <= 32); - - solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2], - src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), - src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), - dst->nb[3] / sizeof(float), ctx.stream()); + if (n <= MAX_N_FAST && k <= MAX_K_FAST) { + solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, + src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); + } else { + solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, + ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index 5ad5623a..d89e35a8 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -21,6 +21,9 @@ #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_16BF HIPBLAS_R_16B #define CUDA_R_32F HIPBLAS_R_32F +#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended #define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned @@ -32,6 +35,7 @@ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define __all_sync(mask, var) __all(var) #define __any_sync(mask, var) __any(var) +#define cublasStrsmBatched hipblasStrsmBatched #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h index 8c55a2e4..221e67f9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h @@ -12,11 +12,16 @@ #define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT #define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH +#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT 
MUBLAS_DIAG_NON_UNIT #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS #define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH #define CUDA_R_16F MUSA_R_16F #define CUDA_R_16BF MUSA_R_16BF #define CUDA_R_32F MUSA_R_32F +#define cublasStrsmBatched mublasStrsmBatched #define cublasComputeType_t cudaDataType_t #define cublasCreate mublasCreate #define cublasDestroy mublasDestroy diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index 7b7d1c12..f24270bb 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -769,9 +769,16 @@ ggml_metal_device_t ggml_metal_device_init(void) { #endif dev->props.use_shared_buffers = dev->props.has_unified_memory; +#if TARGET_OS_OSX + // In case of eGPU, shared memory may be preferable. + dev->props.use_shared_buffers |= [dev->mtl_device location] == MTLDeviceLocationExternal; +#endif if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { dev->props.use_shared_buffers = false; } + if (getenv("GGML_METAL_SHARED_BUFFERS_ENABLE") != NULL) { + dev->props.use_shared_buffers = true; + } dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8a83427f..df79f9f7 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -661,6 +661,7 @@ struct vk_device_struct { vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; vk_pipeline pipeline_tri[2]; + vk_pipeline pipeline_diag[2]; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; @@ -724,6 +725,11 @@ struct vk_device_struct { vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_soft_max_back_f32; + + vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16; + vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16; + vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; @@ -759,7 +765,8 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; + // [2] is for whether to take n_experts from spec constant (0) or push constant (1) + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2]; std::vector all_pipelines; @@ -1151,6 +1158,7 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_topk_moe_push_constants { uint32_t n_rows; + uint32_t n_experts_push; uint32_t n_expert_used; float clamp_min; float clamp_max; @@ -3732,6 +3740,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -3919,6 +3928,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); @@ -3998,6 +4010,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, 
soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -4206,10 +4225,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); - for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + for (uint32_t use_push = 0; use_push < 2; ++use_push) { + for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + } } for (auto &c : compiles) { @@ -8276,6 +8297,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const switch (op) { case GGML_OP_GET_ROWS: GGML_ASSERT(src1->type == 
GGML_TYPE_I32); + if (src0->type == GGML_TYPE_I32) { + // i32 src only supports i32 result + GGML_ASSERT(dst->type == GGML_TYPE_I32); + return ctx->device->pipeline_get_rows[src0->type]; + } if (dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_get_rows[src0->type]; } @@ -8402,6 +8428,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16]; } return nullptr; + case GGML_OP_DIAG: + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16]; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -8556,7 +8588,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; + // use n_experts from push constant if it's not equal to the power of two spec constant + bool use_push = dst->ne[0] != (1u << idx); + return ctx->device->pipeline_topk_moe[idx][mode][use_push]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -9093,6 +9127,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_COS: case GGML_OP_LOG: case GGML_OP_TRI: + case GGML_OP_DIAG: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -9780,6 +9815,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p)); } +static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p)); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); p.param1 = ggml_get_op_params_f32(dst, 0); @@ -10113,7 +10154,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, { + vk_op_soft_max_push_constants pc { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], @@ -10124,7 +10165,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, n_head_log2, nrows_x, src2 != nullptr - }); + }; + + if (ncols <= 16384) { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc)); + } else { + + vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a; + vk_subbuffer buf_c = src2 ? 
ggml_vk_tensor_subbuffer(ctx, src2) : buf_a; + vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst); + + uint32_t elems_per_wg = 128 * 4; + uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg); + size_t tmp_size = num_wgs * nrows_x * sizeof(float); + + if (ctx->prealloc_size_x < tmp_size) { + ctx->prealloc_size_x = tmp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_size_y < tmp_size) { + ctx->prealloc_size_y = tmp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size }; + vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size }; + + std::array elements = { num_wgs, nrows_x, 1 }; + + vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32; + vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32; + vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements); + + ctx->prealloc_x_need_sync = true; + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -10160,6 +10249,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, vk_op_topk_moe_push_constants pc {}; pc.n_rows = n_rows; + pc.n_experts_push = n_experts; pc.n_expert_used = n_expert_used; if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; @@ -11859,6 +11949,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TRI: ggml_vk_tri(ctx, compute_ctx, src0, node); + break; + case GGML_OP_DIAG: + ggml_vk_diag(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node); @@ -12859,8 +12953,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc } const int n_expert = softmax->ne[0]; - // n_expert must be a power of 2 - if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) { + if (n_expert > (1 << (num_topk_moe_pipelines-1))) { return false; } @@ -13999,6 +14092,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_I32: return true; default: return false; @@ -14123,6 +14217,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LOG: case GGML_OP_TRI: + case GGML_OP_DIAG: return (op->src[0]->type == GGML_TYPE_F32 || 
op->src[0]->type == GGML_TYPE_F16) && op->type == op->src[0]->type; case GGML_OP_ARGSORT: @@ -14751,6 +14846,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_log(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_TRI) { tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); + } else if (tensor->op == GGML_OP_DIAG) { + tensor_clone = ggml_diag(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp new file mode 100644 index 00000000..cd3f42f4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + + if (i10 == i11) { + const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val); + } else { + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 4bef48b0..0379e5d5 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -256,6 +256,9 @@ void main() { barrier(); } + // prevent race on tmpsh + barrier(); + // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index cd82e4ab..c995ab14 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -302,6 +302,9 @@ void main() { barrier(); } + // prevent race on tmpsh + barrier(); + // reduce across threads float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index 76d83041..e88bdd05 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -26,9 +26,9 @@ void main() { const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; #if defined(DATA_A_BF16) - FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); + TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); + TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]); #endif #ifndef 
OPTIMIZATION_ERROR_WORKAROUND data_d[d_offset + i00] = D_TYPE(v); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index 0b74b332..c5f5e9cb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -7,34 +7,50 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; -void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { - const uint y_idx = i * QUANT_K + 32 * ib32; - - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const float d = float(data_a[ibi].d); - const uint qh = data_a[ibi].qh[ib32]; - const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); - const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; - +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, + const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx_base = i * QUANT_K + 32 * ib32; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint base_b_idx = (j * p.batch_stride_b + b_offset + y_idx_base) / 4; [[unroll]] for (uint l = 0; l < 4; ++l) { - const uint qs = data_a[ibi].qs[4 * ib32 + l]; - const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); - const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + const vec4 b_val_0 = vec4(data_b_v4[base_b_idx + 2 * l]); + const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]); - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); - vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + // index for data_a + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const uint16_t grid = uint16_t(iq1s_grid[qs | (idxhi << 8)]); + + const float delta_val = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; + const vec4 delta_v = vec4(delta_val); + const vec4 fbits0 = vec4( + float(bitfieldExtract(grid, 0, 2)), + float(bitfieldExtract(grid, 2, 2)), + float(bitfieldExtract(grid, 4, 2)), + float(bitfieldExtract(grid, 6, 2)) + ); + const vec4 fbits1 = vec4( + float(bitfieldExtract(grid, 8, 2)), + float(bitfieldExtract(grid, 10, 2)), + float(bitfieldExtract(grid, 12, 2)), + float(bitfieldExtract(grid, 14, 2)) + ); + + vec4 sum_v = fma(b_val_0, fbits0 + delta_v, vec4(0.0)); + sum_v = fma(b_val_1, fbits1 + delta_v, sum_v); + FLOAT_TYPE sum = dot(sum_v, vec4(1.0)); - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int k = 0; k < 4; ++k) { - sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, - fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); - } temp[j][n] = fma(dl, sum, temp[j][n]); + ibi += num_blocks_per_row; } } - ibi += num_blocks_per_row; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index ee5ded2e..58ede044 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -244,17 +244,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint iqs = idx % 128; // 0..127 const uint n = iqs / 64; // 0,1 - const uint b = (iqs % 64) / 32; // 0,1 + const uint b = ((iqs % 64) / 32) * 4; // 0,4 const uint is_b = (iqs % 16) / 8; // 0,1 const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 const uint is = 8 * n + qhshift + is_b; // 0..15 - const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 - const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint qhi = n * 16 + (iqs % 16); // 0..31 const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), - dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); + const uint ql = (uint(data_a_packed16[ib].ql[qsi]) >> b) & 0x0F0F; + const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303; + const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp new file mode 100644 index 00000000..39c46639 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp @@ -0,0 +1,62 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // 
Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy_start + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + if (tid == 0) { + max_val = vals[0]; + data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp new file mode 100644 index 00000000..69524f5f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp @@ -0,0 +1,79 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = get_slope(rowx); + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); + sum += val; + data_d[i] = D_TYPE(val); + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + if (tid == 0) { + sum = vals[0]; + data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp new file mode 100644 index 00000000..06efd7d9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "soft_max_large_common.glsl" + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.y; + const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) { + if (i + tid < gl_NumWorkGroups.x) { + max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]); + sum += data_s[rowx * gl_NumWorkGroups.x + i + tid]; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + sumsh[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(max_val, vals[tid + s]); + sumsh[tid] += sumsh[tid + s]; + } + barrier(); + } + + max_val = vals[0]; + sum = sumsh[0]; + + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl new file mode 100644 index 00000000..6636d1f8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl @@ -0,0 +1,53 @@ +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; + uint has_sinks; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 128; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(constant_id = 1) const uint num_iters = 4; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; +layout (binding = 4) buffer M {float data_m[];}; +layout (binding = 5) buffer S {float 
data_s[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +float get_slope(uint rowx) { + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = (rowx / p.ne01) % p.ne02; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + return slope; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 5cd0785d..b83a2b9d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -10,6 +10,7 @@ layout (push_constant) uniform parameter { uint n_rows; + uint n_experts_push; uint n_expert_used; float clamp_min; float clamp_max; @@ -18,11 +19,16 @@ layout (push_constant) uniform parameter layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; -layout(constant_id = 1) const uint n_experts = 512; +layout(constant_id = 1) const uint n_experts_spec = 512; layout(constant_id = 2) const bool with_norm = true; layout(constant_id = 3) const bool late_softmax = false; +layout(constant_id = 4) const bool nexperts_use_push = false; -const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; +uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE); layout (binding = 0, std430) readonly buffer Logits {float logits[];}; layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; @@ -94,7 +100,7 @@ void main() { } if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, lane, false); + softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push); } // at this point, each thread holds a portion of softmax, diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 92bae088..b0ade078 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -704,13 +704,15 @@ void process_shaders() { shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? 
"get_rows.comp" : "get_rows_quant.comp"; if (tname == "f16") { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); } else { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); } - string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); } + string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); @@ -854,6 +856,8 @@ void process_shaders() { string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("diag_f16", "diag.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("diag_f32", "diag.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -899,6 +903,13 @@ void process_shaders() { string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_large3_f32_f16", 
"soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c index fc0196eb..c9242a15 100644 --- a/ml/backend/ggml/ggml/src/ggml.c +++ b/ml/backend/ggml/ggml/src/ggml.c @@ -5265,8 +5265,6 @@ struct ggml_tensor * ggml_flash_attn_ext( if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && - "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); GGML_ASSERT(q->ne[2] % mask->ne[2] == 0); @@ -7573,6 +7571,11 @@ size_t ggml_quantize_chunk( //////////////////////////////////////////////////////////////////////////////// +void ggml_log_get(ggml_log_callback * log_callback, void ** user_data) { + *log_callback = g_logger_state.log_callback; + *user_data = g_logger_state.log_callback_user_data; +} + void ggml_log_set(ggml_log_callback log_callback, void * user_data) { g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default; g_logger_state.log_callback_user_data = user_data; diff --git a/ml/nn/rope/options.go b/ml/nn/rope/options.go index 84b92677..1724128a 100644 --- a/ml/nn/rope/options.go +++ b/ml/nn/rope/options.go @@ -77,6 +77,13 @@ func WithMRoPE(sections []int) func(*Options) { } } +func WithVision(sections []int) func(*Options) { + return func(opts *Options) { + opts.Type |= 1<<3 | 1<<4 + opts.MRoPE.Sections = sections + } +} + func WithInterleaveMRoPE(sections []int) func(*Options) { return func(opts *Options) { opts.Type |= 1<<3 | 1<<5 diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index acb58743..765331bf 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -2,9 +2,7 @@ package model import ( "cmp" - "fmt" "iter" - "log/slog" "slices" "strings" @@ -245,14 +243,6 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { return ids, nil } -type lazyIdsString struct { - ids []int32 -} - -func (l lazyIdsString) LogValue() slog.Value { - return slog.AnyValue(fmt.Sprint(l.ids)) -} - func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { @@ -277,6 +267,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) + logutil.Trace("decoded", "string", sb.String(), "from", ids) return sb.String(), nil } diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go index 5e7ca5e9..79cb3a3c 100644 --- a/model/models/bert/embed.go +++ b/model/models/bert/embed.go @@ -129,35 +129,34 @@ func (o Options) headDim() int { } func New(c fs.Config) (model.Model, error) { + vocab := &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.cls_token_id"), + c.Uint("tokenizer.ggml.bos_token_id"), + )), + }, + AddEOS: 
c.Bool("tokenizer.ggml.add_eos_token", true), + EOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.separator_token_id"), + //nolint:misspell + // NOTE: "seperator_token_id" is a typo in model metadata but we need to + // support it for compatibility. + c.Uint("tokenizer.ggml.seperator_token_id"), + c.Uint("tokenizer.ggml.eos_token_id"), + )), + }, + } + var processor model.TextProcessor switch c.String("tokenizer.ggml.model", "bert") { case "bert": - processor = model.NewWordPiece( - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Ints("tokenizer.ggml.token_type"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: []int32{ - int32(cmp.Or( - c.Uint("tokenizer.ggml.cls_token_id"), - c.Uint("tokenizer.ggml.bos_token_id"), - )), - }, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true), - EOS: []int32{ - int32(cmp.Or( - c.Uint("tokenizer.ggml.separator_token_id"), - //nolint:misspell - // NOTE: "seperator_token_id" is a typo in model metadata but we need to - // support it for compatibility. - c.Uint("tokenizer.ggml.seperator_token_id"), - c.Uint("tokenizer.ggml.eos_token_id"), - )), - }, - }, - true, - ) + processor = model.NewWordPiece(vocab, true) default: return nil, model.ErrUnsupportedTokenizer } diff --git a/model/models/models.go b/model/models/models.go index 85bf9a7f..b471e816 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -13,6 +13,7 @@ import ( _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/nomicbert" + _ "github.com/ollama/ollama/model/models/olmo3" _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" diff --git a/model/models/olmo3/model.go b/model/models/olmo3/model.go new file mode 100644 index 00000000..523c00e6 --- /dev/null +++ b/model/models/olmo3/model.go @@ -0,0 +1,223 @@ +package olmo3 + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +const ( + cacheTypeSWA = 0 + cacheTypeCausal = 1 +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + eps, ropeBase, ropeScale float32 + + originalContextLength int + attnFactor float32 + + ropeType string + ropeExtrapolation float32 + + slidingWindowPattern []bool +} + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +func New(c fs.Config) (model.Model, error) { + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + } + + processor := model.NewBytePairEncoding( + &vocabulary, + 
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + ) + + m := Model{ + TextProcessor: processor, + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base", 1e4), + ropeScale: c.Float("rope.scaling.factor", 1), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + attnFactor: c.Float("rope.scaling.attn_factor", 1), + ropeType: c.String("rope.scaling.type"), + ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0), + slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), + }, + } + + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + + return &m, nil +} + +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + QNorm *nn.RMSNorm `gguf:"attn_q_norm"` + KNorm *nn.RMSNorm `gguf:"attn_k_norm"` +} + +func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, isSWA bool) ml.Tensor { + freqScale := float32(1.0) + ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()} + + if !isSWA { + freqScale = 1. / o.ropeScale + if o.originalContextLength > 0 { + ropeOpts = append(ropeOpts, + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(o.ropeExtrapolation), + ) + } + } + + return nn.RoPE(ctx, states, positions, o.hiddenSize/o.numHeads, o.ropeBase, freqScale, ropeOpts...) 
+} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor { + batchSize := hiddenState.Dim(1) + headDim := m.hiddenSize / m.numHeads + + query := sa.Query.Forward(ctx, hiddenState) + query = sa.QNorm.Forward(ctx, query, m.eps) + query = query.Reshape(ctx, headDim, m.numHeads, batchSize) + query = m.Options.applyRotaryPositionEmbeddings(ctx, query, positions, isSWA) + + key := sa.Key.Forward(ctx, hiddenState) + key = sa.KNorm.Forward(ctx, key, m.eps) + key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize) + key = m.Options.applyRotaryPositionEmbeddings(ctx, key, positions, isSWA) + + value := sa.Value.Forward(ctx, hiddenState) + value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, m.hiddenSize, batchSize) + + return sa.Output.Forward(ctx, attention) +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + isSWA := m.isSWALayer(layer) + return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift, isSWA), nil +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type Layer struct { + SelfAttention *SelfAttention + PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"` + MLP *MLP + PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"` +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor { + residual := hiddenState + + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA) + + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps) + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLP.Forward(ctx, hiddenState, m) + hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps) + + return hiddenState.Add(ctx, residual) +} + +// OLMo3 has Sliding Window Attention (SWA) for 3 out of every 4 layers. 
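(Editorial aside, not part of the diff.) The sliding-window pattern above decides, per layer, both which KV cache the wrapper selects and whether the YaRN-style RoPE scaling in applyRotaryPositionEmbeddings is applied. A minimal sketch of that mapping, using a hypothetical 8-layer pattern; the names and sizes here are illustrative only:

package main

import "fmt"

func main() {
	// Hypothetical sliding-window pattern: three SWA layers, then one full-attention layer.
	pattern := []bool{true, true, true, false, true, true, true, false}

	for i, isSWA := range pattern {
		if isSWA {
			// SWA layer: sliding-window KV cache; plain NeoX RoPE (freqScale stays 1).
			fmt.Printf("layer %d: SWA cache, unscaled RoPE\n", i)
		} else {
			// Full-attention layer: causal KV cache; YaRN-scaled RoPE
			// (freqScale = 1/ropeScale, original context length and extrapolation factor set).
			fmt.Printf("layer %d: causal cache, YaRN-scaled RoPE\n", i)
		}
	}
}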
+func (m *Model) isSWALayer(layerIdx int) bool { + return m.Options.slidingWindowPattern[layerIdx] +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + cacheType := cacheTypeSWA + + isSWA := m.isSWALayer(i) + if !isSWA { + cacheType = cacheTypeCausal + } + + wc, ok := m.Cache.(*kvcache.WrapperCache) + if !ok { + return nil, fmt.Errorf("expected *kvcache.WrapperCache, got %T", m.Cache) + } + wc.SetLayerType(cacheType) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState), nil +} + +func init() { + model.Register("olmo3", New) +} diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 13fa3fee..81296a81 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -2,7 +2,6 @@ package qwen25vl import ( "bytes" - "fmt" "image" "slices" @@ -33,7 +32,7 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), EOS: append( @@ -54,19 +53,18 @@ func New(c fs.Config) (model.Model, error) { } func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) { - image, _, err := image.Decode(bytes.NewReader(multimodalData)) + img, _, err := image.Decode(bytes.NewReader(multimodalData)) if err != nil { return nil, nil, err } - f32s, grid, err := m.ImageProcessor.ProcessImage(image) + f32s, grid, err := m.ImageProcessor.ProcessImage(img) if err != nil { return nil, nil, err } // Calculate tensor dimensions - patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize * - m.ImageProcessor.patchSize * m.ImageProcessor.patchSize + patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize numPatches := grid.Temporal * grid.Height * grid.Width pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches) @@ -85,11 +83,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input } visionOutputs := m.VisionModel.Forward(ctx, pixels, grid) - return []input.Multimodal{{Tensor: visionOutputs}}, nil + return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil } // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + // Reset position cache + m.positionCache = m.positionCache[:0] var result []*input.Input var ( @@ -98,40 +98,37 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { visionEndToken int32 = 151653 ) - nImg := 0 + appendInput := func(i *input.Input, p int) int { + result = append(result, i) + m.positionCache = append(m.positionCache, int32(p)) + return p + 1 + } + + var p int for _, inp := range inputs { if inp.Multimodal == nil { // If not a multimodal input, add it to the result unchanged - result = append(result, inp) + p = 
appendInput(inp, p) } else { - // Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix - // the image tokens with a prompt, so we add a prefix here - nImg++ - pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true) - if err != nil { - return nil, fmt.Errorf("failed to encode image prompt: %w", err) - } - for i := range pre { - result = append(result, &input.Input{Token: pre[i]}) - } - - patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) - // First add the vision start token - result = append(result, &input.Input{Token: visionStartToken}) + p = appendInput(&input.Input{Token: visionStartToken}, p) // Add the image token with the multimodal tensor data at the first position - result = append(result, &input.Input{ + tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1) + appendInput(&input.Input{ Token: imageToken, Multimodal: inp.Multimodal, MultimodalHash: inp.MultimodalHash, - SameBatch: patchesPerChunk, - }) + SameBatch: tokensPerGrid, + }, p) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) - result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...) + for range tokensPerGrid - 1 { + appendInput(&input.Input{Token: imageToken}, p) + } - result = append(result, &input.Input{Token: visionEndToken}) + grid := inp.Multimodal[0].Data.(*Grid) + p = appendInput(&input.Input{Token: visionEndToken}, p+max(grid.Width/m.spatialMergeSize, grid.Height/m.spatialMergeSize)) } } @@ -139,9 +136,58 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + // Initial token embedding + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx) - return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) + positionSlice := func() [][]int32 { + s := [][]int32{ + make([]int32, len(batch.Positions)), + make([]int32, len(batch.Positions)), + make([]int32, len(batch.Positions)), + make([]int32, len(batch.Positions)), + } + for i, position := range batch.Positions { + if position < int32(len(m.positionCache)) { + position = m.positionCache[position] + } else if len(m.positionCache) > 0 { + position = position - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1 + } + + s[0][i] = position + s[1][i] = position + s[2][i] = position + } + return s + }() + + for _, mi := range batch.Multimodal { + img := mi.Multimodal[0].Tensor + ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1)))) + if grid, ok := mi.Multimodal[0].Data.(*Grid); ok { + for i := range img.Dim(1) { + w := grid.Width / m.spatialMergeSize + positionSlice[1][mi.Index+i] += int32(i / w) + positionSlice[2][mi.Index+i] += int32(i % w) + } + } + } + + positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice)) + + // Process through transformer layers + for i, layer := range m.TextModel.Layers { + m.Cache.SetLayer(i) + + var lastLayerOutputs ml.Tensor + if i == len(m.TextModel.Layers)-1 { + lastLayerOutputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextOptions) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps) + return m.Output.Forward(ctx, hiddenStates), nil } func init() { diff --git 
a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index b4db6043..61b072d6 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -8,20 +8,17 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn/rope" - "github.com/ollama/ollama/model/input" ) type TextOptions struct { hiddenSize, numHeads, numKVHeads int ropeDim, originalContextLength int eps, ropeBase, ropeScale float32 + mropeSections []int } func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { - return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, - rope.WithOriginalContextLength(o.originalContextLength), - rope.WithTypeNeoX(), - ) + return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithMRoPE(o.mropeSections)) } type TextModel struct { @@ -31,6 +28,7 @@ type TextModel struct { Output *nn.Linear `gguf:"output,alt:token_embd"` *TextOptions + positionCache []int32 } func NewTextModel(c fs.Config) *TextModel { @@ -45,6 +43,14 @@ func NewTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.scaling.factor", 1), + mropeSections: func() []int { + sections := c.Ints("rope.mrope_section") + s := make([]int, len(sections)) + for i, section := range sections { + s[i] = int(section) + } + return s + }(), }, } @@ -84,6 +90,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + m.positionCache = nil return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil } @@ -130,28 +137,3 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten hiddenState = l.MLP.Forward(ctx, hiddenState, opts) return hiddenState.Add(ctx, residual) } - -func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) { - // Initial token embedding - hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) - - for _, mi := range batch.Multimodal { - img := mi.Multimodal[0].Tensor - ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1)))) - } - - // Process through transformer layers - for i, layer := range m.Layers { - cache.SetLayer(i) - - var lastLayerOutputs ml.Tensor - if i == len(m.Layers)-1 { - lastLayerOutputs = outputs - } - - hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions) - } - - hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) - return m.Output.Forward(ctx, hiddenStates), nil -} diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index bfdafabe..f1275437 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -7,48 +7,28 @@ import ( "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/rope" ) -// We only support batch size of 1 -var batchSize int = 1 - -func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { - x1 := t.Slice(ctx, 0, 0, t.Dim(0)/2, 1) - x2 := t.Slice(ctx, 0, t.Dim(0)/2, t.Dim(0), 1).Contiguous(ctx) - return x2.Scale(ctx, -1).Concat(ctx, 
x1, 0) -} - -func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor { - return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin)) -} - -func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor { - // Create a flat slice for the mask (all -inf initially to block all attention) - flat := make([]float32, seqLength*seqLength) - for i := range flat { - flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention +func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor { + // Initialize a 2D mask with -Inf + s := make([][]float32, seqLength) + for i := range s { + s[i] = slices.Repeat([]float32{float32(math.Inf(-1))}, seqLength) } // Fill in the mask with zeros for tokens that CAN attend to each other for i := 1; i < len(bounds); i++ { - start := bounds[i-1] - end := bounds[i] - - // Enable attention within this sequence block by setting values to 0 + start, end := bounds[i-1], bounds[i] + // Enable attention within this sequence block for row := start; row < end; row++ { for col := start; col < end; col++ { - idx := row*seqLength + col - flat[idx] = 0.0 // 0 allows attention, -inf blocks it + s[row][col] = 0.0 } } } - mask := ctx.Input().FromFloats(flat, seqLength, seqLength) - - // Reshape to match [seqLength, seqLength, 1] for broadcasting - mask = mask.Reshape(ctx, seqLength, seqLength, 1) - - return mask + return ctx.Input().FromFloats(slices.Concat(s...), seqLength, seqLength) } type VisionSelfAttention struct { @@ -58,17 +38,17 @@ type VisionSelfAttention struct { Output *nn.Linear `gguf:"attn_out"` } -func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { query := sa.Query.Forward(ctx, hiddenStates) key := sa.Key.Forward(ctx, hiddenStates) value := sa.Value.Forward(ctx, hiddenStates) - query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize) - key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) - value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize) + query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1)) + key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1)) + value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1)) - query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) - key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) + query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) + key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) // Scale factor for scaled dot-product attention scale := 1.0 / math.Sqrt(float64(opts.headDim)) @@ -77,6 +57,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m query = query.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + kq := key.MulmatFullPrec(ctx, query) kq = kq.Scale(ctx, scale) if mask != nil { @@ -85,7 +66,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m kq = kq.Softmax(ctx) kqv := value.Mulmat(ctx, kq) attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2)) return sa.Output.Forward(ctx, 
attention) } @@ -98,10 +79,7 @@ type VisionMLP struct { } func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { - // Using activation as specified in config (likely GELU or SiLU/Swish) - gateOutput := mlp.Gate.Forward(ctx, hiddenStates) - hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) - + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -112,10 +90,10 @@ type VisionEncoderLayer struct { MLP *VisionMLP } -func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, positions, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor { residual := hiddenStates hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps) - hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts) + hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, positions, mask, opts) hiddenStates = hiddenStates.Add(ctx, residual) residual = hiddenStates @@ -139,6 +117,17 @@ type VisionModelOptions struct { temporalPatchSize int } +func (o VisionModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { + return nn.RoPE(ctx, states, positions, o.headDim/2, o.ropeTheta, 1, + rope.WithVision([]int{ + o.headDim / 4, + o.headDim / 4, + o.headDim / 4, + o.headDim / 4, + }), + ) +} + type PatchEmbedding struct { PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"` PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"` @@ -186,7 +175,7 @@ func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, op hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize) // Reshape the normalized output to view the hidden size dimension - reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize) + reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize)) hidden := pm.MLP0.Forward(ctx, reshaped) activated := hidden.GELU(ctx) @@ -209,36 +198,53 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) // Extract patch embeddings hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions) - positionEmbedding := m.PositionalEmbedding(ctx, grid) - - windowIndex, bounds := m.WindowIndex(ctx, grid) - + index, bounds := m.windowIndex(grid) spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize + windowIndex := ctx.Input().FromInts(index, len(index)) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit) - hiddenStates = hiddenStates.Rows(ctx, windowIndex) + hiddenStates = hiddenStates.Rows(ctx, windowIndex.Argsort(ctx)) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit) - positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit) - positionEmbedding = positionEmbedding.Rows(ctx, windowIndex) - positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit) - positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0) + positions := ctx.Input().FromInts(func() []int32 { + s := [][]int32{ + make([]int32, 
grid.Height*grid.Width), + make([]int32, grid.Height*grid.Width), + make([]int32, grid.Height*grid.Width), + make([]int32, grid.Height*grid.Width), + } - cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) - cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) - sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) + var cur int + for y := 0; y < grid.Height; y += m.spatialMergeSize { + for x := 0; x < grid.Width; x += m.spatialMergeSize { + for dy := range 2 { + for dx := range 2 { + i := int(index[cur/spatialMergeUnit]) * spatialMergeUnit + i += cur % spatialMergeUnit + s[0][i] = int32(y + dy) + s[1][i] = int32(x + dx) + s[2][i] = int32(y + dy) + s[3][i] = int32(x + dx) + cur++ + } + } + } + } + + return slices.Concat(s...) + }(), grid.Height*grid.Width*4) + + mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds) - mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads) // Apply encoder layers for i, layer := range m.Layers { if slices.Contains(m.fullAttnBlocks, int32(i)) { - hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions) + hiddenStates = layer.Forward(ctx, hiddenStates, positions, nil, m.VisionModelOptions) } else { hiddenStates = layer.Forward( ctx, hiddenStates, - cos, - sin, + positions, mask, m.VisionModelOptions, ) @@ -246,102 +252,43 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) } hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions) - reverseWindowIndex := windowIndex.Argsort(ctx) - return hiddenStates.Rows(ctx, reverseWindowIndex) + return hiddenStates.Rows(ctx, windowIndex) } -// WindowIndex divides the grid into windows and returns: -// 1. A tensor containing flattened indices of all grid points organized by windows +// windowIndex divides the grid into windows and returns: +// 1. A slice of grid point indices organized by windows // 2. A slice of boundaries that mark where each window's data begins and ends // in the flattened representation, scaled by spatialMergeSize squared // // The boundaries slice always starts with 0 and contains cumulative ending // positions for each window, allowing downstream processing to identify // window boundaries in the tensor data. 
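(Editorial aside, not part of the diff.) For intuition about how this index is used in Forward: gathering rows with the argsort of the index groups them window by window, and gathering the result with the index itself restores the original order, because the argsort of a permutation is its inverse. A minimal host-side sketch, assuming Rows gathers out[i] = in[idx[i]] and Argsort returns ascending sort indices; the 4x4 merged grid with 2x2 windows and the argsort helper below are hypothetical:

package main

import (
	"fmt"
	"sort"
)

// argsort returns, for each rank k, the position of value k in idx
// (idx is assumed to be a permutation of 0..len(idx)-1).
func argsort(idx []int32) []int32 {
	order := make([]int, len(idx))
	for i := range order {
		order[i] = i
	}
	sort.Slice(order, func(a, b int) bool { return idx[order[a]] < idx[order[b]] })
	out := make([]int32, len(idx))
	for i, o := range order {
		out[i] = int32(o)
	}
	return out
}

func main() {
	// index as windowIndex would produce it for a hypothetical 4x4 merged grid
	// split into 2x2 windows: each entry is the row's rank in window-major order.
	index := []int32{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}

	rows := make([]int, len(index))
	for i := range rows {
		rows[i] = i // stand-ins for hidden-state rows
	}

	// hiddenStates.Rows(ctx, windowIndex.Argsort(ctx)): group rows window by window.
	inv := argsort(index)
	windowed := make([]int, len(rows))
	for i, j := range inv {
		windowed[i] = rows[j]
	}

	// final hiddenStates.Rows(ctx, windowIndex): restore the original row order.
	restored := make([]int, len(rows))
	for i, j := range index {
		restored[i] = windowed[j]
	}

	fmt.Println(windowed) // [0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15]
	fmt.Println(restored) // [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15], original order restored
}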
-func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) { - vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize +func (m *VisionModel) windowIndex(grid *Grid) (index []int32, bounds []int) { + height := grid.Height / m.spatialMergeSize + width := grid.Width / m.spatialMergeSize + window := m.windowSize / m.patchSize / m.spatialMergeSize - llmGridH := grid.Height / m.spatialMergeSize - llmGridW := grid.Width / m.spatialMergeSize + index = make([]int32, height*width) - // Calculate window parameters - numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize))) - numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize))) + bounds = make([]int, 0, ((height+window-1)/window)*((width+window-1)/window)+1) + bounds = append(bounds, 0) - // Initialize index_new slice - var index []int32 - - // Initialize bounds with the first element as 0 - bounds := []int{0} - totalSeqLen := 0 - - // Process each window without padding - for wh := range numWindowsH { - for ww := range numWindowsW { - // Calculate window boundaries - hStart := wh * vitMergerWindowSize - wStart := ww * vitMergerWindowSize - hEnd := min(hStart+vitMergerWindowSize, llmGridH) - wEnd := min(wStart+vitMergerWindowSize, llmGridW) - - // Calculate sequence length for this window - seqLen := (hEnd - hStart) * (wEnd - wStart) - - // Collect indices for this window - for h := hStart; h < hEnd; h++ { - for w := wStart; w < wEnd; w++ { - index = append(index, int32(h*llmGridW+w)) + var cur int32 + for y := 0; y < height; y += window { + for x := 0; x < width; x += window { + h1 := min(window, height-y) + w1 := min(window, width-x) + for dy := range h1 { + for dx := range w1 { + win := (y+dy)*width + (x + dx) + index[win] = cur + cur++ } } - - totalSeqLen += seqLen - bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0]) + bounds = append(bounds, int(cur)*window) } } - - t := ctx.Input().FromInts(index, len(index)) - - return t, bounds -} - -// PositionalEmbedding generates rotary position embeddings for attention mechanisms -func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor { - dim := m.headDim / 2 - freq := dim / 2 - theta := float64(m.ropeTheta) - merge := m.spatialMergeSize - - // Create frequency patterns for position encoding - maxGridSize := max(grid.Height, grid.Width) - freqVals := make([]float32, freq*maxGridSize) - for i := range maxGridSize { - for j := range freq { - freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) - } - } - freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize) - - // Create position coordinates (y,x pairs) for the grid - // In PyTorch: Equivalent to generating position ids with torch.arange() - coords := make([]int32, 0, grid.Height*grid.Width*2) - for y := range grid.Height { - for x := range grid.Width { - coords = append(coords, int32(y), int32(x)) - } - } - pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height) - - // Reshape and permute positions to match spatial merging pattern - pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) - pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge) - pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge) - - // Use position indices to look up corresponding frequency values - positionalEmbedding := freqs.Rows(ctx, pos) - 
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2) - return positionalEmbedding + return index, bounds } // newVisionModel creates a new instance of the Qwen vision model diff --git a/model/models/qwen25vl/process_image.go b/model/models/qwen25vl/process_image.go index ce5ded29..66803d4a 100644 --- a/model/models/qwen25vl/process_image.go +++ b/model/models/qwen25vl/process_image.go @@ -19,8 +19,8 @@ type ImageProcessor struct { maxPixels int factor int rescaleFactor float32 - imageMean []float32 - imageStd []float32 + imageMean [3]float32 + imageStd [3]float32 } // newImageProcessor creates a new image processor with default values @@ -34,11 +34,11 @@ func newImageProcessor(c fs.Config) ImageProcessor { temporalPatchSize: 2, mergeSize: mergeSize, minPixels: 56 * 56, - maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit + maxPixels: int(c.Uint("vision.max_pixels", 2<<20)), // 2M limit factor: patchSize * mergeSize, rescaleFactor: 1.0 / 255.0, - imageMean: imageproc.ClipDefaultMean[:], - imageStd: imageproc.ClipDefaultSTD[:], + imageMean: imageproc.ClipDefaultMean, + imageStd: imageproc.ClipDefaultSTD, } } @@ -90,13 +90,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) // Resize image using existing functions resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) - normalizedPixels := imageproc.Normalize( - resizedImg, - [3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]}, - [3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]}, - true, // rescale - true, // channelFirst - ) + normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true) // Calculate grid dimensions grid := &Grid{ diff --git a/model/parsers/deepseek3.go b/model/parsers/deepseek3.go new file mode 100644 index 00000000..bafc85ba --- /dev/null +++ b/model/parsers/deepseek3.go @@ -0,0 +1,292 @@ +package parsers + +import ( + "encoding/json" + "errors" + "log/slog" + "strings" + "unicode" + + "github.com/ollama/ollama/api" +) + +type DeepSeek3ParserState int + +const ( + DeepSeekCollectingThinking DeepSeek3ParserState = iota + DeepSeekCollectingContent + DeepSeekCollectingToolCalls + DeepSeekCollectingToolOutput +) + +const ( + deepseekThinkingCloseTag = "</think>" + deepseekToolCallsBeginTag = "<|tool▁calls▁begin|>" + deepseekToolCallsEndTag = "<|tool▁calls▁end|>" + deepseekToolCallBeginTag = "<|tool▁call▁begin|>" + deepseekToolCallEndTag = "<|tool▁call▁end|>" + deepseekToolSepTag = "<|tool▁sep|>" + deepseekToolOutputBeginTag = "<|tool▁output▁begin|>" + deepseekToolOutputEndTag = "<|tool▁output▁end|>" +) + +type DeepSeek3Parser struct { + state DeepSeek3ParserState + buffer strings.Builder + hasThinkingSupport bool +} + +func (p *DeepSeek3Parser) HasToolSupport() bool { + return true +} + +func (p *DeepSeek3Parser) HasThinkingSupport() bool { + return p.hasThinkingSupport +} + +func (p *DeepSeek3Parser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) { + prefill := lastMessage != nil && lastMessage.Role == "assistant" + + // Check both model capability AND request preference + thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool()) + + if !thinkingEnabled { + p.state = DeepSeekCollectingContent + return + } + + if prefill && lastMessage.Content != "" { + p.state = DeepSeekCollectingContent + return + } + + p.state = DeepSeekCollectingThinking +} + +func (p
*DeepSeek3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.setInitialState(lastMessage, tools, thinkValue) + return tools +} + +type deepseekEvent interface { + isDeepSeekEvent() +} + +type deepseekEventThinkingContent struct { + content string +} + +type deepseekEventContent struct { + content string +} + +type deepseekEventToolCall struct { + toolCall api.ToolCall +} + +func (deepseekEventThinkingContent) isDeepSeekEvent() {} +func (deepseekEventContent) isDeepSeekEvent() {} +func (deepseekEventToolCall) isDeepSeekEvent() {} + +func (p *DeepSeek3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case deepseekEventToolCall: + toolCalls = append(toolCalls, event.toolCall) + case deepseekEventThinkingContent: + thinkingSb.WriteString(event.content) + case deepseekEventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *DeepSeek3Parser) parseEvents() []deepseekEvent { + var all []deepseekEvent + + keepLooping := true + for keepLooping { + var events []deepseekEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + return all +} + +func (p *DeepSeek3Parser) eat() ([]deepseekEvent, bool) { + var events []deepseekEvent + bufStr := p.buffer.String() + if bufStr == "" { + return events, false + } + + switch p.state { + case DeepSeekCollectingThinking: + if strings.Contains(bufStr, deepseekThinkingCloseTag) { // thinking[] -> content + split := strings.SplitN(bufStr, deepseekThinkingCloseTag, 2) + thinking := split[0] + thinking = strings.TrimRightFunc(thinking, unicode.IsSpace) + + remaining := split[1] + remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = DeepSeekCollectingContent + + if len(thinking) > 0 { + events = append(events, deepseekEventThinkingContent{content: thinking}) + } + return events, true + } else if overlapLen := overlap(bufStr, deepseekThinkingCloseTag); overlapLen > 0 { // partial + beforePartialTag := bufStr[:len(bufStr)-overlapLen] + trailingLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingLen + + unambiguous := bufStr[:ambiguousStart] + ambiguous := bufStr[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, deepseekEventThinkingContent{content: unambiguous}) + } + return events, false + } else { // otherwise its thinking content + whitespaceLen := trailingWhitespaceLen(bufStr) + ambiguousStart := len(bufStr) - whitespaceLen + + unambiguous := bufStr[:ambiguousStart] + ambiguous := bufStr[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, deepseekEventThinkingContent{content: unambiguous}) + } + return events, false + } + + case DeepSeekCollectingContent: + switch { + case strings.Contains(bufStr, deepseekToolCallsBeginTag): // content[<|tool▁calls▁begin|>] -> tool calls + split := strings.SplitN(bufStr, deepseekToolCallsBeginTag, 2) + contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace) + remaining := split[1] + + p.buffer.Reset() + 
p.buffer.WriteString(remaining) + p.state = DeepSeekCollectingToolCalls + + if len(contentBefore) > 0 { + events = append(events, deepseekEventContent{content: contentBefore}) + } + return events, true + case strings.Contains(bufStr, deepseekToolOutputBeginTag): // content[<|tool▁output▁begin|>] -> tool output + split := strings.SplitN(bufStr, deepseekToolOutputBeginTag, 2) + contentBefore := split[0] // Don't trim whitespace - preserve spaces + remaining := split[1] + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = DeepSeekCollectingToolOutput + + if len(contentBefore) > 0 { + events = append(events, deepseekEventContent{content: contentBefore}) + } + return events, true + default: // otherwise its content + p.buffer.Reset() + if len(bufStr) > 0 { + events = append(events, deepseekEventContent{content: bufStr}) + } + return events, false + } + + case DeepSeekCollectingToolCalls: + if idx := strings.Index(bufStr, deepseekToolCallBeginTag); idx != -1 { + startIdx := idx + len(deepseekToolCallBeginTag) + if endIdx := strings.Index(bufStr[startIdx:], deepseekToolCallEndTag); endIdx != -1 { + toolCallContent := bufStr[startIdx : startIdx+endIdx] + + if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil { + remaining := bufStr[startIdx+endIdx+len(deepseekToolCallEndTag):] + remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) + + p.buffer.Reset() + p.buffer.WriteString(remaining) + + events = append(events, deepseekEventToolCall{toolCall: toolCall}) + return events, true + } else { + slog.Warn("deepseek tool call parsing failed", "error", err) + } + } + } + + if idx := strings.Index(bufStr, deepseekToolCallsEndTag); idx != -1 { + remaining := bufStr[idx+len(deepseekToolCallsEndTag):] + remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace) + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = DeepSeekCollectingContent + + return events, true + } + + return events, false + + case DeepSeekCollectingToolOutput: + if idx := strings.Index(bufStr, deepseekToolOutputEndTag); idx != -1 { + toolOutputContent := bufStr[:idx] + remaining := bufStr[idx+len(deepseekToolOutputEndTag):] + // Don't trim whitespace - preserve spaces after tool output tags + + p.buffer.Reset() + p.buffer.WriteString(remaining) + p.state = DeepSeekCollectingContent + + if len(toolOutputContent) > 0 { + events = append(events, deepseekEventContent{content: toolOutputContent}) + } + return events, true + } + + return events, false + } + + return events, false +} + +func (p *DeepSeek3Parser) parseToolCallContent(content string) (api.ToolCall, error) { + // Expected format: tool_name<|tool▁sep|>{args} + parts := strings.SplitN(content, deepseekToolSepTag, 2) + if len(parts) < 2 { + return api.ToolCall{}, errors.New("invalid format") + } + + toolName := strings.TrimSpace(parts[0]) + argsJSON := strings.TrimSpace(parts[1]) + + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { + return api.ToolCall{}, err + } + + return api.ToolCall{ + Function: api.ToolCallFunction{ + Name: toolName, + Arguments: args, + }, + }, nil +} diff --git a/model/parsers/deepseek3_test.go b/model/parsers/deepseek3_test.go new file mode 100644 index 00000000..4e3180d4 --- /dev/null +++ b/model/parsers/deepseek3_test.go @@ -0,0 +1,721 @@ +package parsers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestDeepSeekParser(t *testing.T) { + tests := []struct { + name string + input string + 
expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + hasThinking bool + }{ + { + name: "simple_content", + input: "Hello, how are you?", + expectedContent: "Hello, how are you?", + hasThinking: false, + }, + { + name: "thinking_content", + input: "I need to think about this...The answer is 42.", + expectedThinking: "I need to think about this...", + expectedContent: "The answer is 42.", + hasThinking: true, + }, + { + name: "no_thinking_simple", + input: "Just a regular response.", + expectedContent: "Just a regular response.", + hasThinking: false, + }, + { + name: "thinking_with_newlines", + input: "Let me think:\n- Point 1\n- Point 2\n\nHere's my answer.", + expectedThinking: "Let me think:\n- Point 1\n- Point 2", + expectedContent: "Here's my answer.", + hasThinking: true, + }, + { + name: "tool_call_simple", + input: "I'll check the weather.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "I'll check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "multiple_tool_calls", + input: "Getting weather for both cities.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"London\"}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Getting weather for both cities.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "London", + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "tool_output", + input: "Here's the weather: <|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|> Hope that helps!", + expectedContent: "Here's the weather: Temperature: 22°C, Sunny Hope that helps!", + hasThinking: false, + }, + { + name: "complex_tool_arguments", + input: "Processing data.<|tool▁calls▁begin|><|tool▁call▁begin|>process_data<|tool▁sep|>{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Processing data.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process_data", + Arguments: api.ToolCallFunctionArguments{ + "items": []interface{}{"item1", "item2"}, + "config": map[string]interface{}{"enabled": true, "threshold": 0.95}, + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "thinking_with_tool_call", // technically this can't happen, but the parser can handle it + input: "Let me check the weather...I'll get that for you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedThinking: "Let me check the weather...", + expectedContent: "I'll get that for you.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + hasThinking: true, + }, + { + name: "empty_content", + input: "", + expectedContent: "", + hasThinking: false, + }, + { + name: "only_thinking", + input: "Just thinking content", + 
expectedThinking: "Just thinking content", + expectedContent: "", + hasThinking: true, + }, + { + name: "multiple_tool_outputs", + input: "Results: <|tool▁output▁begin|>Paris: 22°C<|tool▁output▁end|> and <|tool▁output▁begin|>London: 18°C<|tool▁output▁end|>", + expectedContent: "Results: Paris: 22°C and London: 18°C", + hasThinking: false, + }, + { + name: "unicode_content", + input: "مرحبا بالعالم! 你好世界! 🌍", + expectedContent: "مرحبا بالعالم! 你好世界! 🌍", + hasThinking: false, + }, + { + name: "emoji_passthrough", + input: "Task completed ✅ 🎉", + expectedContent: "Task completed ✅ 🎉", + hasThinking: false, + }, + { + name: "emoji_after_tool_call", + input: "I'll help you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"location\":\"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>完成 ✅", + expectedContent: "I'll help you.完成 ✅", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "newlines_and_whitespace", + input: "Line 1\n\nLine 3\t\tTabbed content", + expectedContent: "Line 1\n\nLine 3\t\tTabbed content", + hasThinking: false, + }, + { + name: "thinking_with_unicode", + input: "我在思考这个问题...答案是42。", + expectedThinking: "我在思考这个问题...", + expectedContent: "答案是42。", + hasThinking: true, + }, + { + name: "tool_call_with_unicode_args", + input: "Searching for information.<|tool▁calls▁begin|><|tool▁call▁begin|>search<|tool▁sep|>{\"query\":\"北京天气\",\"language\":\"中文\"}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Searching for information.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: api.ToolCallFunctionArguments{ + "query": "北京天气", + "language": "中文", + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "tool_output_with_unicode", + input: "天气信息: <|tool▁output▁begin|>北京: 25°C, 晴天<|tool▁output▁end|> 希望对您有帮助!", + expectedContent: "天气信息: 北京: 25°C, 晴天 希望对您有帮助!", + hasThinking: false, + }, + { + name: "mixed_content_with_special_chars", + input: "Price: $100 & tax @ 10% = $110 <|tool▁output▁begin|>Total: $110<|tool▁output▁end|> (final)", + expectedContent: "Price: $100 & tax @ 10% = $110 Total: $110 (final)", + hasThinking: false, + }, + { + name: "tool_call_with_special_chars", + input: "Processing data.<|tool▁calls▁begin|><|tool▁call▁begin|>execute_command<|tool▁sep|>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Processing data.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "execute_command", + Arguments: api.ToolCallFunctionArguments{ + "command": "ls && echo \"done\"", + "path": "/home/user", + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "thinking_with_special_chars", + input: "Let me calculate: 2+2=4 & 3*3=9...The results are correct!", + expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...", + expectedContent: "The results are correct!", + hasThinking: true, + }, + { + name: "empty_tool_call_args", + input: "Pinging server.<|tool▁calls▁begin|><|tool▁call▁begin|>ping<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Pinging server.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "ping", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + hasThinking: false, + }, + { + name: "empty_tool_output", + input: "Checking status: <|tool▁output▁begin|><|tool▁output▁end|> 
No output received.", + expectedContent: "Checking status: No output received.", + hasThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking} + parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking}) + + content, thinking, calls, err := parser.Add(tt.input, true) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + if diff := cmp.Diff(tt.expectedContent, content); diff != "" { + t.Errorf("Content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" { + t.Errorf("Thinking mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" { + t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestDeepSeekParser_Streaming(t *testing.T) { + tests := []struct { + name string + chunks []string + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + hasThinking bool + }{ + { + name: "streaming_simple_content", + chunks: []string{"Hello, ", "how are ", "you?"}, + expectedContent: "Hello, how are you?", + hasThinking: false, + }, + { + name: "streaming_thinking", + chunks: []string{"I need to ", "think about this", "...", "The answer is 42."}, + expectedThinking: "I need to think about this...", + expectedContent: "The answer is 42.", + hasThinking: true, + }, + { + name: "streaming_tool_call", + chunks: []string{"I'll check weather.", "<|tool▁calls▁begin|>", "<|tool▁call▁begin|>get_weather", "<|tool▁sep|>{\"location\":\"Paris\"}", "<|tool▁call▁end|><|tool▁calls▁end|>"}, + expectedContent: "I'll check weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + hasThinking: false, + }, + { + name: "streaming_thinking_with_partial_tag", + chunks: []string{"Thinking about this", "...", "Done thinking."}, + expectedThinking: "Thinking about this...", + expectedContent: "Done thinking.", + hasThinking: true, + }, + { + name: "streaming_tool_output", + chunks: []string{"Weather info: ", "<|tool▁output▁begin|>", "25°C, Sunny", "<|tool▁output▁end|>", " Enjoy!"}, + expectedContent: "Weather info: 25°C, Sunny Enjoy!", + hasThinking: false, + }, + { + name: "streaming_with_split_tags", + chunks: []string{"Content before ", "<|tool▁calls▁begin|><|tool▁call▁begin|>test", "<|tool▁sep|>{}", "<|tool▁call▁end|><|tool▁calls▁end|>", " after"}, + expectedContent: "Content before after", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + hasThinking: false, + }, + { + name: "streaming_thinking_with_split_end_tag", + chunks: []string{"Thinking content", "", "Regular content"}, + expectedThinking: "Thinking content", + expectedContent: "Regular content", + hasThinking: true, + }, + { + name: "streaming_unicode_content", + chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"}, + expectedContent: "مرحبا بالعالم! 
你好世界!", + hasThinking: false, + }, + { + name: "streaming_multiple_tool_outputs", + chunks: []string{"Results: ", "<|tool▁output▁begin|>", "Paris: 22°C", "<|tool▁output▁end|>", " and ", "<|tool▁output▁begin|>", "London: 18°C", "<|tool▁output▁end|>"}, + expectedContent: "Results: Paris: 22°C and London: 18°C", + hasThinking: false, + }, + { + name: "streaming_tool_call_with_split_json", + chunks: []string{"Processing.", "<|tool▁calls▁begin|><|tool▁call▁begin|>calc<|tool▁sep|>{\"x\":", "42,\"y\":", "24}<|tool▁call▁end|><|tool▁calls▁end|>"}, + expectedContent: "Processing.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "calc", + Arguments: api.ToolCallFunctionArguments{ + "x": float64(42), + "y": float64(24), + }, + }, + }, + }, + hasThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking} + parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking}) + + var allContent, allThinking string + var allCalls []api.ToolCall + + for i, chunk := range tt.chunks { + done := i == len(tt.chunks)-1 + content, thinking, calls, err := parser.Add(chunk, done) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + allContent += content + allThinking += thinking + allCalls = append(allCalls, calls...) + } + + if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" { + t.Errorf("Content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" { + t.Errorf("Thinking mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" { + t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestDeepSeekParser_HasThinkingSupport(t *testing.T) { + tests := []struct { + name string + hasThinking bool + expectedSupport bool + }{ + { + name: "thinking_enabled", + hasThinking: true, + expectedSupport: true, + }, + { + name: "thinking_disabled", + hasThinking: false, + expectedSupport: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking} + if got := parser.HasThinkingSupport(); got != tt.expectedSupport { + t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport) + } + }) + } +} + +func TestDeepSeekParser_HasToolSupport(t *testing.T) { + parser := &DeepSeek3Parser{} + if !parser.HasToolSupport() { + t.Error("HasToolSupport() should return true") + } +} + +func TestDeepSeekParser_Init(t *testing.T) { + parser := &DeepSeek3Parser{hasThinkingSupport: true} + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "test_tool", + }, + }, + } + + returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true}) + + if diff := cmp.Diff(tools, returnedTools); diff != "" { + t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff) + } + + // Test initial state is set to thinking when enabled + if parser.state != DeepSeekCollectingThinking { + t.Errorf("Expected initial state to be DeepSeekCollectingThinking, got %v", parser.state) + } +} + +func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) { + tests := []struct { + name string + content string + expected api.ToolCall + expectError bool + }{ + { + name: "valid_tool_call", + content: "get_weather<|tool▁sep|>{\"location\":\"Paris\"}", + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: 
api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + { + name: "complex_arguments", + content: "process_data<|tool▁sep|>{\"items\":[\"a\",\"b\"],\"config\":{\"enabled\":true}}", + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "process_data", + Arguments: api.ToolCallFunctionArguments{ + "items": []interface{}{"a", "b"}, + "config": map[string]interface{}{"enabled": true}, + }, + }, + }, + }, + { + name: "empty_arguments", + content: "ping<|tool▁sep|>{}", + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "ping", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + { + name: "unicode_in_tool_name", + content: "获取天气<|tool▁sep|>{\"城市\":\"北京\"}", + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "获取天气", + Arguments: api.ToolCallFunctionArguments{ + "城市": "北京", + }, + }, + }, + }, + { + name: "special_chars_in_arguments", + content: "execute<|tool▁sep|>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}", + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "execute", + Arguments: api.ToolCallFunctionArguments{ + "command": "ls && echo \"done\"", + "path": "/home/user", + }, + }, + }, + }, + { + name: "numeric_arguments", + content: "calculate<|tool▁sep|>{\"x\":3.14,\"y\":42,\"enabled\":true}", + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: api.ToolCallFunctionArguments{ + "x": 3.14, + "y": float64(42), + "enabled": true, + }, + }, + }, + }, + { + name: "invalid_format_no_separator", + content: "get_weather{\"location\":\"Paris\"}", + expectError: true, + }, + { + name: "invalid_json", + content: "get_weather<|tool▁sep|>{invalid json}", + expectError: true, + }, + { + name: "empty_tool_name", + content: "<|tool▁sep|>{\"arg\":\"value\"}", + expectError: false, // This should work, just empty name + expected: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "", + Arguments: api.ToolCallFunctionArguments{ + "arg": "value", + }, + }, + }, + }, + { + name: "missing_json_part", + content: "tool_name<|tool▁sep|>", + expectError: true, + }, + } + + parser := &DeepSeek3Parser{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parser.parseToolCallContent(tt.content) + + if tt.expectError { + if err == nil { + t.Error("Expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.expected, result); diff != "" { + t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestDeepSeekParser_EdgeCases(t *testing.T) { + tests := []struct { + name string + input string + expectedContent string + expectedThinking string + hasThinking bool + }{ + { + name: "nested_think_tags_in_thinking", + input: "Outer thinking inner contentFinal content", + expectedThinking: "Outer thinking inner", + expectedContent: "contentFinal content", + hasThinking: true, + }, + { + name: "multiple_think_close_tags", + input: "First thoughtSecond thoughtFinal content", + expectedThinking: "First thought", + expectedContent: "Second thoughtFinal content", + hasThinking: true, + }, + { + name: "empty_thinking_content", + input: "Just content", + expectedThinking: "", + expectedContent: "Just content", + hasThinking: true, + }, + { + name: "thinking_disabled_with_think_tags", + input: "Some contentMore content", + expectedContent: "Some contentMore content", + hasThinking: false, + }, + { + name: 
"malformed_tool_call_missing_sep", + input: "Testing.<|tool▁calls▁begin|><|tool▁call▁begin|>bad_tool{\"arg\":\"value\"}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Testing.", + hasThinking: false, + }, + { + name: "malformed_tool_call_invalid_json", + input: "Testing.<|tool▁calls▁begin|><|tool▁call▁begin|>bad_tool<|tool▁sep|>{invalid json}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "Testing.", + hasThinking: false, + }, + { + name: "partial_tool_tag_at_end", + input: "Content with partial <|tool▁calls▁", + expectedContent: "Content with partial <|tool▁calls▁", + hasThinking: false, + }, + { + name: "partial_think_tag_at_end", + input: "Thinking contentLine 1\nLine 2\nLine 3<|tool▁output▁end|>\nDone.", + expectedContent: "Output:\nLine 1\nLine 2\nLine 3\nDone.", + hasThinking: false, + }, + { + name: "consecutive_tool_calls", + input: "First.<|tool▁calls▁begin|><|tool▁call▁begin|>tool1<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>Second.<|tool▁calls▁begin|><|tool▁call▁begin|>tool2<|tool▁sep|>{}<|tool▁call▁end|><|tool▁calls▁end|>", + expectedContent: "First.", + hasThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &DeepSeek3Parser{hasThinkingSupport: tt.hasThinking} + parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking}) + + content, thinking, _, err := parser.Add(tt.input, true) + if err != nil { + t.Fatalf("Add() error = %v", err) + } + + if diff := cmp.Diff(tt.expectedContent, content); diff != "" { + t.Errorf("Content mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" { + t.Errorf("Thinking mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/model/parsers/functiongemma.go b/model/parsers/functiongemma.go new file mode 100644 index 00000000..35f8791c --- /dev/null +++ b/model/parsers/functiongemma.go @@ -0,0 +1,323 @@ +package parsers + +import ( + "fmt" + "regexp" + "strings" + + "github.com/ollama/ollama/api" +) + +type FunctionGemmaParserState int + +const ( + FunctionGemmaCollectingContent FunctionGemmaParserState = iota + FunctionGemmaCollectingToolCalls +) + +const ( + functionGemmaFunctionCallOpen = "" + functionGemmaFunctionCallClose = "" +) + +// This format uses call:name{args} for tool calls. 
+type FunctionGemmaParser struct { + state FunctionGemmaParserState + buffer strings.Builder + tools []api.Tool +} + +func (p *FunctionGemmaParser) HasToolSupport() bool { return true } +func (p *FunctionGemmaParser) HasThinkingSupport() bool { return false } + +func (p *FunctionGemmaParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + p.state = FunctionGemmaCollectingContent + return tools +} + +type functionGemmaEvent interface { + isFunctionGemmaEvent() +} + +type FunctionGemmaEventContent struct { + content string +} + +type functionGemmaEventToolCall struct { + toolCall api.ToolCall +} + +func (FunctionGemmaEventContent) isFunctionGemmaEvent() {} +func (functionGemmaEventToolCall) isFunctionGemmaEvent() {} + +func (p *FunctionGemmaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case functionGemmaEventToolCall: + toolCalls = append(toolCalls, event.toolCall) + case FunctionGemmaEventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), "", toolCalls, nil +} + +func (p *FunctionGemmaParser) parseEvents() []functionGemmaEvent { + var all []functionGemmaEvent + + keepLooping := true + for keepLooping { + var events []functionGemmaEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + return all +} + +// emitWithPartialCheck extracts unambiguous content before a potential partial tag +func (p *FunctionGemmaParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) { + if overlapLen := overlap(bufStr, tag); overlapLen > 0 { + beforePartialTag := bufStr[:len(bufStr)-overlapLen] + return beforePartialTag, bufStr[len(beforePartialTag):] + } + return bufStr, "" +} + +func (p *FunctionGemmaParser) eat() ([]functionGemmaEvent, bool) { + bufStr := p.buffer.String() + if bufStr == "" { + return nil, false + } + + switch p.state { + case FunctionGemmaCollectingContent: + if strings.Contains(bufStr, functionGemmaFunctionCallOpen) { + split := strings.SplitN(bufStr, functionGemmaFunctionCallOpen, 2) + content := split[0] + p.buffer.Reset() + p.buffer.WriteString(split[1]) + p.state = FunctionGemmaCollectingToolCalls + if content != "" { + return []functionGemmaEvent{FunctionGemmaEventContent{content: content}}, true + } + return nil, true + } + unambig, ambig := p.emitWithPartialCheck(bufStr, functionGemmaFunctionCallOpen) + p.buffer.Reset() + p.buffer.WriteString(ambig) + if unambig != "" { + return []functionGemmaEvent{FunctionGemmaEventContent{content: unambig}}, false + } + return nil, false + + case FunctionGemmaCollectingToolCalls: + if strings.Contains(bufStr, functionGemmaFunctionCallClose) { + split := strings.SplitN(bufStr, functionGemmaFunctionCallClose, 2) + remaining := split[1] + p.buffer.Reset() + p.buffer.WriteString(remaining) + + var events []functionGemmaEvent + if tc, err := p.parseToolCall(split[0]); err == nil { + events = append(events, functionGemmaEventToolCall{toolCall: tc}) + } + + if !strings.Contains(remaining, functionGemmaFunctionCallOpen) { + p.state = FunctionGemmaCollectingContent + } + return events, true + } + return nil, false + } + + return nil, false +} + +// Matches call:function_name{args} +var functionGemmaCallRegex = 
regexp.MustCompile(`call:([^{]+)\{(.*)\}`) + +func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error) { + toolCall := api.ToolCall{} + + // Extract function name and arguments + match := functionGemmaCallRegex.FindStringSubmatch(content) + if len(match) < 3 { + return toolCall, nil + } + + toolCall.Function.Name = match[1] + argsStr := match[2] + + // Parse arguments + toolCall.Function.Arguments = p.parseArguments(argsStr) + + return toolCall, nil +} + +// parseArguments parses the key:value,key:value format +func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments { + args := make(api.ToolCallFunctionArguments) + if argsStr == "" { + return args + } + + // Split by comma, but handle nested structures + parts := p.splitArguments(argsStr) + + for _, part := range parts { + // Find the first colon to split key:value + colonIdx := strings.Index(part, ":") + if colonIdx == -1 { + continue + } + + key := part[:colonIdx] + value := part[colonIdx+1:] + + // Parse the value + args[key] = p.parseValue(value) + } + + return args +} + +// splitArguments splits arguments by comma, respecting nested structures +func (p *FunctionGemmaParser) splitArguments(argsStr string) []string { + var parts []string + var current strings.Builder + depth := 0 + inEscape := false + + for i := 0; i < len(argsStr); i++ { + ch := argsStr[i] + + // Check for <escape> tags + if i+8 <= len(argsStr) && argsStr[i:i+8] == "<escape>" { + inEscape = !inEscape + current.WriteString("<escape>") + i += 7 // Skip the rest of <escape> + continue + } + + if !inEscape { + switch ch { + case '{', '[': + depth++ + current.WriteByte(ch) + case '}', ']': + depth-- + current.WriteByte(ch) + case ',': + if depth == 0 { + if current.Len() > 0 { + parts = append(parts, current.String()) + current.Reset() + } + continue + } + current.WriteByte(ch) + default: + current.WriteByte(ch) + } + } else { + current.WriteByte(ch) + } + } + + if current.Len() > 0 { + parts = append(parts, current.String()) + } + + return parts +} + +// parseValue parses a single value from the FunctionGemma format +func (p *FunctionGemmaParser) parseValue(value string) any { + // Check for escaped string + if strings.HasPrefix(value, "<escape>") && strings.HasSuffix(value, "<escape>") { + // Remove the escape tags + return value[8 : len(value)-8] + } + + // Check for boolean + if value == "true" { + return true + } + if value == "false" { + return false + } + + // Check for number + if num, ok := parseNumber(value); ok { + return num + } + + // Check for array + if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + return p.parseArray(value[1 : len(value)-1]) + } + + // Check for object + if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") { + return p.parseObject(value[1 : len(value)-1]) + } + + // Default to string + return value +} + +// parseArray parses an array value +func (p *FunctionGemmaParser) parseArray(content string) []any { + var result []any + parts := p.splitArguments(content) + for _, part := range parts { + result = append(result, p.parseValue(part)) + } + return result +} + +// parseObject parses an object value +func (p *FunctionGemmaParser) parseObject(content string) map[string]any { + result := make(map[string]any) + parts := p.splitArguments(content) + for _, part := range parts { + colonIdx := strings.Index(part, ":") + if colonIdx == -1 { + continue + } + key := part[:colonIdx] + value := part[colonIdx+1:] + result[key] = p.parseValue(value) + } + return result +} + +// parseNumber tries to parse a string
as a number +func parseNumber(s string) (any, bool) { + // Try integer first + var intVal int64 + if _, err := fmt.Sscanf(s, "%d", &intVal); err == nil { + // Check if the entire string was consumed + if fmt.Sprintf("%d", intVal) == s { + return intVal, true + } + } + + // Try float + var floatVal float64 + if _, err := fmt.Sscanf(s, "%f", &floatVal); err == nil { + return floatVal, true + } + + return nil, false +} diff --git a/model/parsers/functiongemma_test.go b/model/parsers/functiongemma_test.go new file mode 100644 index 00000000..227abdb8 --- /dev/null +++ b/model/parsers/functiongemma_test.go @@ -0,0 +1,423 @@ +package parsers + +import ( + "testing" + + "github.com/ollama/ollama/api" + "github.com/stretchr/testify/assert" +) + +func TestFunctionGemmaParser(t *testing.T) { + tests := []struct { + name string + chunks []string + tools []api.Tool + expectedCalls []api.ToolCall + expectedText string + }{ + { + name: "plain_content", + chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"}, + expectedCalls: nil, + expectedText: "Hello, world!", + }, + { + name: "simple_tool_call", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "get", "_", "weather", "{", + "city", ":", "<", "escape", ">", "Paris", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + expectedText: "", + }, + { + name: "content_before_tool_call", + chunks: []string{ + "L", "et", " ", "me", " ", "check", ".", + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "get", "_", "weather", "{", + "city", ":", "<", "escape", ">", "Paris", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + expectedText: "Let me check.", + }, + { + name: "numeric_arguments", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "add", "{", + "a", ":", "1", ",", "b", ":", "2", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "add", + Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)}, + }, + }, + }, + expectedText: "", + }, + { + name: "boolean_arguments", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "set", "_", "flag", "{", + "enabled", ":", "true", ",", "verbose", ":", "false", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_flag", + Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false}, + }, + }, + }, + expectedText: "", + }, + { + name: "multiple_tool_calls", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "get", "_", "weather", "{", + "city", ":", "<", "escape", ">", "Paris", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + "<", "start", "_", 
"function", "_", "call", ">", + "call", ":", "get", "_", "weather", "{", + "city", ":", "<", "escape", ">", "London", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "London"}, + }, + }, + }, + expectedText: "", + }, + { + name: "array_argument", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "process", "{", + "items", ":", "[", + "<", "escape", ">", "a", "<", "escape", ">", ",", + "<", "escape", ">", "b", "<", "escape", ">", ",", + "<", "escape", ">", "c", "<", "escape", ">", + "]", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process", + Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}}, + }, + }, + }, + expectedText: "", + }, + { + name: "object_argument", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "update", "{", + "data", ":", "{", + "name", ":", "<", "escape", ">", "test", "<", "escape", ">", ",", + "value", ":", "42", + "}", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "update", + Arguments: api.ToolCallFunctionArguments{ + "data": map[string]any{"name": "test", "value": int64(42)}, + }, + }, + }, + }, + expectedText: "", + }, + { + name: "empty_input", + chunks: []string{}, + expectedCalls: nil, + expectedText: "", + }, + { + name: "tool_call_with_no_arguments", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "get", "_", "time", "{", "}", + "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_time", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + expectedText: "", + }, + { + name: "content_with_angle_brackets", + chunks: []string{ + "The", " ", "result", " ", "is", " ", "a", " ", "<", "value", ">", " ", "tag", + }, + expectedCalls: nil, + expectedText: "The result is a tag", + }, + { + name: "float_argument", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "set", "_", "temp", "{", + "value", ":", "3", ".", "14", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_temp", + Arguments: api.ToolCallFunctionArguments{"value": 3.14}, + }, + }, + }, + expectedText: "", + }, + { + name: "content_after_tool_call", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "test", "{", "}", + "<", "end", "_", "function", "_", "call", ">", + "Done", "!", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + expectedText: "Done!", + }, + { + name: "unicode_content_and_arguments", + chunks: []string{ + "こんにちは", " ", + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "greet", "{", + "name", ":", "<", "escape", ">", "日本語", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "greet", + Arguments: 
api.ToolCallFunctionArguments{"name": "日本語"}, + }, + }, + }, + expectedText: "こんにちは ", + }, + { + name: "multiple_params_sorted", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "search", "{", + "query", ":", "<", "escape", ">", "test", "<", "escape", ">", ",", + "limit", ":", "10", ",", + "offset", ":", "0", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: api.ToolCallFunctionArguments{ + "query": "test", + "limit": int64(10), + "offset": int64(0), + }, + }, + }, + }, + expectedText: "", + }, + { + name: "nested_object_argument", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "create", "{", + "config", ":", "{", + "settings", ":", "{", + "enabled", ":", "true", ",", + "name", ":", "<", "escape", ">", "test", "<", "escape", ">", + "}", + "}", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create", + Arguments: api.ToolCallFunctionArguments{ + "config": map[string]any{ + "settings": map[string]any{ + "enabled": true, + "name": "test", + }, + }, + }, + }, + }, + }, + expectedText: "", + }, + { + name: "partial_start_tag_in_content", + chunks: []string{ + "Hello", " ", "<", "start", " ", "world", + }, + expectedCalls: nil, + expectedText: "Hello ", + "call", ":", "get", "_", "weather", "{", + "city", ":", "<", "escape", ">", "Paris", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "get", "_", "time", "{", + "timezone", ":", "<", "escape", ">", "UTC", "<", "escape", ">", + "}", "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_time", + Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"}, + }, + }, + }, + expectedText: "", + }, + { + name: "content_between_tool_calls", + chunks: []string{ + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "first", "{", "}", + "<", "end", "_", "function", "_", "call", ">", + "Some", " ", "text", " ", "here", + "<", "start", "_", "function", "_", "call", ">", + "call", ":", "second", "{", "}", + "<", "end", "_", "function", "_", "call", ">", + }, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "first", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "second", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + expectedText: "Some text here", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &FunctionGemmaParser{} + parser.Init(tt.tools, nil, nil) + + var allContent string + var allCalls []api.ToolCall + + for i, chunk := range tt.chunks { + done := i == len(tt.chunks)-1 + content, _, calls, err := parser.Add(chunk, done) + assert.NoError(t, err) + allContent += content + allCalls = append(allCalls, calls...) 
+ } + + // Handle empty chunks case + if len(tt.chunks) == 0 { + content, _, calls, err := parser.Add("", true) + assert.NoError(t, err) + allContent = content + allCalls = calls + } + + assert.Equal(t, tt.expectedText, allContent) + assert.Equal(t, tt.expectedCalls, allCalls) + }) + } +} + +func TestFunctionGemmaParser_HasSupport(t *testing.T) { + parser := &FunctionGemmaParser{} + assert.True(t, parser.HasToolSupport()) + assert.False(t, parser.HasThinkingSupport()) +} diff --git a/model/parsers/nemotron3nano.go b/model/parsers/nemotron3nano.go new file mode 100644 index 00000000..6e662fba --- /dev/null +++ b/model/parsers/nemotron3nano.go @@ -0,0 +1,254 @@ +package parsers + +import ( + "regexp" + "strings" + "unicode" + + "github.com/ollama/ollama/api" +) + +type Nemotron3NanoParserState int + +const ( + Nemotron3NanoCollectingThinking Nemotron3NanoParserState = iota + Nemotron3NanoSkipWhitespaceAfterThinking + Nemotron3NanoCollectingContent + Nemotron3NanoCollectingToolCalls +) + +const ( + nemotronThinkClose = "</think>" + nemotronToolCallOpen = "<tool_call>" + nemotronToolCallClose = "</tool_call>" +) + +type Nemotron3NanoParser struct { + state Nemotron3NanoParserState + buffer strings.Builder + tools []api.Tool +} + +func (p *Nemotron3NanoParser) HasToolSupport() bool { return true } +func (p *Nemotron3NanoParser) HasThinkingSupport() bool { return true } + +func (p *Nemotron3NanoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { + p.tools = tools + + // thinking is enabled if user requests it + thinkingEnabled := thinkValue != nil && thinkValue.Bool() + + prefill := lastMessage != nil && lastMessage.Role == "assistant" + + if !thinkingEnabled { + p.state = Nemotron3NanoCollectingContent + return tools + } + + if prefill && lastMessage.Content != "" { + p.state = Nemotron3NanoCollectingContent + return tools + } + + p.state = Nemotron3NanoCollectingThinking + return tools +} + +type nemotronEvent interface { + isNemotronEvent() +} + +type nemotronEventThinkingContent struct { + content string +} + +type nemotronEventContent struct { + content string +} + +type nemotronEventToolCall struct { + toolCall api.ToolCall +} + +func (nemotronEventThinkingContent) isNemotronEvent() {} +func (nemotronEventContent) isNemotronEvent() {} +func (nemotronEventToolCall) isNemotronEvent() {} + +func (p *Nemotron3NanoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + for _, event := range events { + switch event := event.(type) { + case nemotronEventToolCall: + toolCalls = append(toolCalls, event.toolCall) + case nemotronEventThinkingContent: + thinkingSb.WriteString(event.content) + case nemotronEventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *Nemotron3NanoParser) parseEvents() []nemotronEvent { + var all []nemotronEvent + + keepLooping := true + for keepLooping { + var events []nemotronEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...)
+ } + } + + return all +} + +// emitWithPartialCheck extracts unambiguous content before a potential partial tag +func (p *Nemotron3NanoParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) { + if overlapLen := overlap(bufStr, tag); overlapLen > 0 { + beforePartialTag := bufStr[:len(bufStr)-overlapLen] + trailingLen := trailingWhitespaceLen(beforePartialTag) + return bufStr[:len(beforePartialTag)-trailingLen], bufStr[len(beforePartialTag)-trailingLen:] + } + wsLen := trailingWhitespaceLen(bufStr) + return bufStr[:len(bufStr)-wsLen], bufStr[len(bufStr)-wsLen:] +} + +func (p *Nemotron3NanoParser) eat() ([]nemotronEvent, bool) { + bufStr := p.buffer.String() + if bufStr == "" { + return nil, false + } + + switch p.state { + case Nemotron3NanoCollectingThinking: + if strings.Contains(bufStr, nemotronThinkClose) { + split := strings.SplitN(bufStr, nemotronThinkClose, 2) + thinking := strings.TrimRightFunc(split[0], unicode.IsSpace) + p.buffer.Reset() + remainder := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.buffer.WriteString(remainder) + // Transition to whitespace-skipping state if buffer is empty, + // otherwise go directly to content collection + if remainder == "" { + p.state = Nemotron3NanoSkipWhitespaceAfterThinking + } else { + p.state = Nemotron3NanoCollectingContent + } + if thinking != "" { + return []nemotronEvent{nemotronEventThinkingContent{content: thinking}}, true + } + return nil, true + } + unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronThinkClose) + p.buffer.Reset() + p.buffer.WriteString(ambig) + if unambig != "" { + return []nemotronEvent{nemotronEventThinkingContent{content: unambig}}, false + } + return nil, false + + // We only want to skip whitespace between thinking and content + case Nemotron3NanoSkipWhitespaceAfterThinking: + bufStr = strings.TrimLeftFunc(bufStr, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(bufStr) + if bufStr == "" { + return nil, false + } + p.state = Nemotron3NanoCollectingContent + return nil, true + + case Nemotron3NanoCollectingContent: + if strings.Contains(bufStr, nemotronToolCallOpen) { + split := strings.SplitN(bufStr, nemotronToolCallOpen, 2) + content := strings.TrimRightFunc(split[0], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(split[1]) + p.state = Nemotron3NanoCollectingToolCalls + if content != "" { + return []nemotronEvent{nemotronEventContent{content: content}}, true + } + return nil, true + } + unambig, ambig := p.emitWithPartialCheck(bufStr, nemotronToolCallOpen) + p.buffer.Reset() + p.buffer.WriteString(ambig) + if unambig != "" { + return []nemotronEvent{nemotronEventContent{content: unambig}}, false + } + return nil, false + + case Nemotron3NanoCollectingToolCalls: + if strings.Contains(bufStr, nemotronToolCallClose) { + split := strings.SplitN(bufStr, nemotronToolCallClose, 2) + remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(remaining) + + var events []nemotronEvent + if tc, err := p.parseToolCall(split[0]); err == nil { + events = append(events, nemotronEventToolCall{toolCall: tc}) + } + + if !strings.Contains(remaining, nemotronToolCallOpen) { + p.state = Nemotron3NanoCollectingContent + } + return events, true + } + return nil, false + } + + return nil, false +} + +var ( + nemotronFunctionRegex = regexp.MustCompile(`<function=([^>]+)>`) + nemotronParameterRegex = regexp.MustCompile(`<parameter=([^>]+)>\n?([\s\S]*?)\n?</parameter>`) +) + +func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error) { + toolCall :=
api.ToolCall{} + + // Extract function name + fnMatch := nemotronFunctionRegex.FindStringSubmatch(content) + if len(fnMatch) < 2 { + return toolCall, nil + } + toolCall.Function.Name = fnMatch[1] + + // Extract parameters + toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1) + for _, match := range paramMatches { + if len(match) >= 3 { + paramName := match[1] + paramValue := strings.TrimSpace(match[2]) + + // Try to parse as typed value based on tool definition + toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue) + } + } + + return toolCall, nil +} + +func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any { + // Find the matching tool to get parameter type + var paramType api.PropertyType + for _, tool := range p.tools { + if prop, ok := tool.Function.Parameters.Properties[paramName]; ok { + paramType = prop.Type + break + } + } + + return parseValue(raw, paramType) +} diff --git a/model/parsers/nemotron3nano_test.go b/model/parsers/nemotron3nano_test.go new file mode 100644 index 00000000..a4517fc4 --- /dev/null +++ b/model/parsers/nemotron3nano_test.go @@ -0,0 +1,574 @@ +package parsers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestNemotron3NanoParser(t *testing.T) { + tests := []struct { + name string + input string + thinkValue *api.ThinkValue + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + }{ + { + name: "simple content - no thinking", + input: "Hello, how can I help you?", + thinkValue: nil, + expectedContent: "Hello, how can I help you?", + }, + { + name: "simple content - thinking disabled", + input: "Hello, how can I help you?", + thinkValue: &api.ThinkValue{Value: false}, + expectedContent: "Hello, how can I help you?", + }, + { + name: "thinking then content", + input: "Let me think about this...\nHere is my answer.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think about this...", + expectedContent: "Here is my answer.", + }, + { + name: "thinking with newlines", + input: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude\nThe answer is 42.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Step 1: Analyze\nStep 2: Process\nStep 3: Conclude", + expectedContent: "The answer is 42.", + }, + { + name: "simple tool call", + input: "\n\n\nParis\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + { + name: "content then tool call", + input: "Let me check the weather.\n\n\n\nNYC\n\n\n", + thinkValue: nil, + expectedContent: "Let me check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "NYC"}, + }, + }, + }, + }, + { + name: "tool call with multiple parameters", + input: "\n\n\nSFO\n\n\nNYC\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "book_flight", + Arguments: map[string]any{ + "from": "SFO", + "to": "NYC", + }, + }, + }, + }, + }, + { + name: "multiple tool calls", + input: "\n\n\nSan Francisco\n\n\n\n" + + "\n\n\nNew York\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "San Francisco"}, + }, + }, + { 
+ Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "New York"}, + }, + }, + }, + }, + { + name: "thinking then tool call", + input: "I should check the weather...\n\n\n\nParis\n\n\n", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "I should check the weather...", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + { + name: "thinking content then tool call", + input: "Let me think...\nI'll check for you.\n\n\n\ntest\n\n\n", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think...", + expectedContent: "I'll check for you.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: map[string]any{"query": "test"}, + }, + }, + }, + }, + { + name: "tool call with multiline parameter value", + input: "\n\n\nLine 1\nLine 2\nLine 3\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create_note", + Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"}, + }, + }, + }, + }, + { + name: "empty thinking block - immediate close", + input: "\nHere is my answer.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "", + expectedContent: "Here is my answer.", + }, + { + name: "thinking disabled but model outputs think close anyway", + input: "\nSome content after spurious tag.", + thinkValue: &api.ThinkValue{Value: false}, + expectedContent: "\nSome content after spurious tag.", + }, + { + name: "tool call with no function name - returns empty tool call", + input: "\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}}, + }, + { + name: "content with newlines preserved", + input: "Line 1\n\nLine 2\n\n\nLine 3", + thinkValue: nil, + expectedContent: "Line 1\n\nLine 2\n\n\nLine 3", + }, + { + name: "thinking with only whitespace after close tag", + input: "My thoughts... \n\t\n Content here.", + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "My thoughts...", + expectedContent: "Content here.", + }, + { + name: "unicode content", + input: "Hello 世界! 🌍 Ñoño", + thinkValue: nil, + expectedContent: "Hello 世界! 🌍 Ñoño", + }, + { + name: "tool call with numeric parameter", + input: "\n\n\n42\n\n\n", + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_temp", + Arguments: map[string]any{"value": "42"}, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Nemotron3NanoParser{} + p.Init(nil, nil, tt.thinkValue) + + content, thinking, calls, err := p.Add(tt.input, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Drain remaining content + finalContent, finalThinking, finalCalls, err := p.Add("", true) + if err != nil { + t.Fatalf("unexpected error on done: %v", err) + } + content += finalContent + thinking += finalThinking + calls = append(calls, finalCalls...) 
+ + if diff := cmp.Diff(content, tt.expectedContent); diff != "" { + t.Errorf("content mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { + t.Errorf("thinking mismatch (-got +want):\n%s", diff) + } + if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" { + t.Errorf("calls mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestNemotron3NanoParser_Streaming(t *testing.T) { + tests := []struct { + name string + chunks []string + thinkValue *api.ThinkValue + expectedContent string + expectedThinking string + expectedCalls []api.ToolCall + }{ + { + name: "streaming content character by character", + chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"}, + thinkValue: nil, + expectedContent: "Hello, world!", + }, + { + name: "streaming content small tokens", + chunks: []string{"Hel", "lo", ", ", "how ", "can", " I", " help", " you", " today", "?"}, + thinkValue: nil, + expectedContent: "Hello, how can I help you today?", + }, + { + name: "streaming thinking then content - granular", + chunks: []string{"Let", " me", " th", "ink", " about", " this", "...", "<", "/", "think", ">", "\n", "Here", " is", " my", " answer", "."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Let me think about this...", + expectedContent: "Here is my answer.", + }, + { + name: "streaming thinking with newlines - granular", + chunks: []string{"Step", " 1", ":", " Ana", "lyze\n", "Step", " 2", ":", " Pro", "cess", "", "\n", "The", " ans", "wer."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Step 1: Analyze\nStep 2: Process", + expectedContent: "The answer.", + }, + { + name: "streaming tool call - highly granular", + chunks: []string{"<", "tool", "_", "call", ">", "\n", "<", "func", "tion", "=", "get", "_", "weather", ">", "\n", "<", "param", "eter", "=", "city", ">", "\n", "Par", "is", "\n", "", "\n", "", "\n", ""}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + { + name: "streaming content then tool call - granular", + chunks: []string{"Let", " me", " check", " the", " weather", ".", "\n<", "tool_call", ">", "\n", "", "\n", "", "\n", "NYC", "\n", "", "\n", "", "\n", ""}, + thinkValue: nil, + expectedContent: "Let me check the weather.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "NYC"}, + }, + }, + }, + }, + { + name: "tool call tag split character by character", + chunks: []string{"<", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">", "\n", "<", "f", "u", "n", "c", "t", "i", "o", "n", "=", "t", "e", "s", "t", ">", "\n", "<", "/", "f", "u", "n", "c", "t", "i", "o", "n", ">", "\n", "<", "/", "t", "o", "o", "l", "_", "c", "a", "l", "l", ">"}, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{}, + }, + }, + }, + }, + { + name: "thinking close tag split character by character", + chunks: []string{"I", "'", "m", " ", "t", "h", "i", "n", "k", "i", "n", "g", ".", ".", ".", "<", "/", "t", "h", "i", "n", "k", ">", "\n", "D", "o", "n", "e", "!"}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "I'm thinking...", + expectedContent: "Done!", + }, + { + name: "multiple whitespace after think tag - separate chunks", + chunks: []string{"Thinking...", "", "\n", "\n", " ", "Content here."}, + thinkValue: 
&api.ThinkValue{Value: true}, + expectedThinking: "Thinking...", + expectedContent: "Content here.", + }, + { + name: "tool call with multiple parameters - streaming", + chunks: []string{"\n", "", "\n\n", "SFO\n", "", "\n\nNYC", "\n", "\n\n", ""}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "book_flight", + Arguments: map[string]any{ + "from": "SFO", + "to": "NYC", + }, + }, + }, + }, + }, + { + name: "thinking then content then tool call - streaming", + chunks: []string{"Ana", "lyzing", " your", " request", "...", "\n", "I'll", " check", " that", " for", " you", ".", "\n", "\n", "\n", "\n", "test", " query", "\n\n", "\n", ""}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Analyzing your request...", + expectedContent: "I'll check that for you.", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "search", + Arguments: map[string]any{"query": "test query"}, + }, + }, + }, + }, + { + name: "multiple tool calls - streaming", + chunks: []string{ + "", "\n", "", "\n", + "\n", "San Fran", "cisco\n", "", "\n", + "", "\n", "", "\n", + "\n", "\n", + "\nNew", " York\n", "\n", + "\n", "", + }, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "San Francisco"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "New York"}, + }, + }, + }, + }, + { + name: "tool call with multiline parameter - streaming", + chunks: []string{"\n", "\n", "\n", "Line 1", "\nLine", " 2\n", "Line 3", "\n\n", "\n", ""}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "create_note", + Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"}, + }, + }, + }, + }, + { + name: "empty thinking block", + chunks: []string{"", "\n", "Just content."}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "", + expectedContent: "Just content.", + }, + { + name: "empty input chunks interspersed", + chunks: []string{"Hello", "", " ", "", "world", "", "!"}, + thinkValue: nil, + expectedContent: "Hello world!", + }, + { + name: "tool call immediately after think close - no content", + chunks: []string{"Analyzing...", "", "\n", "", "\n\n\n", ""}, + thinkValue: &api.ThinkValue{Value: true}, + expectedThinking: "Analyzing...", + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{}, + }, + }, + }, + }, + { + name: "tool call with empty parameter value", + chunks: []string{"\n\n\n", "\n\n\n"}, + thinkValue: nil, + expectedCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "test", + Arguments: map[string]any{"name": ""}, + }, + }, + }, + }, + { + name: "partial tool call tag at end - buffered", + chunks: []string{"Here's some content", "") + sb.WriteString(systemPrompt.String()) + + // tool definitions + if len(tools) > 0 { + sb.WriteString("\n\n## Tools\nYou have access to the following tools:\n") + + for _, tool := range tools { + sb.WriteString("\n### " + tool.Function.Name) + sb.WriteString("\nDescription: " + tool.Function.Description) + + // parameters as JSON + parametersJSON, err := json.Marshal(tool.Function.Parameters) + if err == nil { + sb.WriteString("\n\nParameters: " + string(parametersJSON) + "\n") + } + } + + // usage instructions + sb.WriteString("\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n") + 
sb.WriteString("<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\n") + sb.WriteString("Where:\n\n") + sb.WriteString("- `tool_call_name` must be an exact match to one of the available tools\n") + sb.WriteString("- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n") + sb.WriteString("- For multiple tool calls, chain them directly without separators or spaces\n") + } + + // state tracking + isTool := false + isLastUser := false + + // Find the index of the last user message to determine which assistant message is "current" + lastUserIndex := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + lastUserIndex = i + break + } + } + + for i, message := range messages { + switch message.Role { + case "user": + isTool = false + isLastUser = true + sb.WriteString("<|User|>" + message.Content) + + case "assistant": + if len(message.ToolCalls) > 0 { + if isLastUser { + sb.WriteString("<|Assistant|>") + } + isLastUser = false + isTool = false + + if message.Content != "" { + sb.WriteString(message.Content) + } + + sb.WriteString("<|tool▁calls▁begin|>") + for _, toolCall := range message.ToolCalls { + sb.WriteString("<|tool▁call▁begin|>" + toolCall.Function.Name + "<|tool▁sep|>") + + argsJSON, _ := json.Marshal(toolCall.Function.Arguments) + sb.WriteString(string(argsJSON)) + sb.WriteString("<|tool▁call▁end|>") + } + sb.WriteString("<|tool▁calls▁end|><|end▁of▁sentence|>") + } else { + if isLastUser { + sb.WriteString("<|Assistant|>") + hasThinking := message.Thinking != "" + + // only use for the current turn (after last user message) + isCurrentTurn := i > lastUserIndex + if hasThinking && thinking && isCurrentTurn { + sb.WriteString("") + } else { + sb.WriteString("") + } + } + isLastUser = false + + content := message.Content + if isTool { + sb.WriteString(content + "<|end▁of▁sentence|>") + isTool = false + } else { + if strings.Contains(content, "") { + parts := strings.SplitN(content, "", 2) + if len(parts) > 1 { + content = parts[1] + } + } + sb.WriteString(content + "<|end▁of▁sentence|>") + } + } + + case "tool": + isLastUser = false + isTool = true + sb.WriteString("<|tool▁output▁begin|>" + message.Content + "<|tool▁output▁end|>") + } + } + + if isLastUser && !isTool { + sb.WriteString("<|Assistant|>") + if thinking { + sb.WriteString("") + } else { + sb.WriteString("") + } + } + + return sb.String(), nil +} diff --git a/model/renderers/deepseek3_test.go b/model/renderers/deepseek3_test.go new file mode 100644 index 00000000..c43a9f93 --- /dev/null +++ b/model/renderers/deepseek3_test.go @@ -0,0 +1,1043 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestDeepSeekRenderer(t *testing.T) { + tests := []struct { + name string + messages []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + expected string + }{ + { + name: "basic user message", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Hello, how are you?<|Assistant|>`, + }, + { + name: "basic with system message", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>You are a helpful 
assistant.<|User|>Hello, how are you?<|Assistant|>`, + }, + { + name: "multiple system messages", + messages: []api.Message{ + {Role: "system", Content: "First instruction"}, + {Role: "system", Content: "Second instruction"}, + {Role: "user", Content: "Hello"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>First instruction + +Second instruction<|User|>Hello<|Assistant|>`, + }, + { + name: "thinking enabled", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>Hello, how are you?<|Assistant|>`, + }, + { + name: "thinking enabled with system", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello, how are you?<|Assistant|>`, + }, + { + name: "conversation with assistant response", + messages: []api.Message{ + {Role: "user", Content: "What is the capital of France?"}, + {Role: "assistant", Content: "The capital of France is Paris."}, + {Role: "user", Content: "Fantastic!"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>What is the capital of France?<|Assistant|>The capital of France is Paris.<|end▁of▁sentence|><|User|>Fantastic!<|Assistant|>`, + }, + { + name: "assistant with tool calls", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>What's the weather?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>`, + }, + { + name: "assistant with content and tool calls", + messages: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + Content: "I'll check the weather for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>What's the weather in Paris?<|Assistant|>I'll check the weather for you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>`, + }, + { + name: "tool response", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Temperature: 22°C, Sunny"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>What's the weather?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|>`, + }, + { + name: "multiple tool calls", + messages: []api.Message{ + {Role: "user", Content: 
"Get weather for Paris and London"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "London", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Paris: 22°C, Sunny"}, + {Role: "tool", Content: "London: 18°C, Cloudy"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Get weather for Paris and London<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"London"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Paris: 22°C, Sunny<|tool▁output▁end|><|tool▁output▁begin|>London: 18°C, Cloudy<|tool▁output▁end|>`, + }, + { + name: "content with tag removal", + messages: []api.Message{ + {Role: "user", Content: "Think about this"}, + {Role: "assistant", Content: "I'm thinking about this.The answer is 42."}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Think about this<|Assistant|>The answer is 42.<|end▁of▁sentence|>`, + }, + { + name: "empty system message", + messages: []api.Message{ + {Role: "system", Content: ""}, + {Role: "user", Content: "Hello"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Hello<|Assistant|>`, + }, + { + name: "empty assistant content", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: ""}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Hello<|Assistant|><|end▁of▁sentence|>`, + }, + { + name: "special characters", + messages: []api.Message{ + {Role: "user", Content: "What about <|special|> tokens and \"quotes\"?"}, + {Role: "assistant", Content: "They're handled normally."}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>What about <|special|> tokens and "quotes"?<|Assistant|>They're handled normally.<|end▁of▁sentence|>`, + }, + { + name: "tool calls with null content", + messages: []api.Message{ + {Role: "user", Content: "Get weather"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Get weather<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>`, + }, + { + name: "assistant after tool context", + messages: []api.Message{ + {Role: "user", Content: "Process data"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "process", + Arguments: api.ToolCallFunctionArguments{ + "data": "test", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Success"}, + {Role: "assistant", Content: "Done"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Process data<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>process<|tool▁sep|>{"data":"test"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Success<|tool▁output▁end|>Done<|end▁of▁sentence|>`, + }, + { + name: "no 
messages", + messages: []api.Message{}, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>`, + }, + { + name: "only system messages", + messages: []api.Message{ + {Role: "system", Content: "System instruction"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>System instruction`, + }, + { + name: "multiple think tags in content", + messages: []api.Message{ + {Role: "user", Content: "Complex question"}, + {Role: "assistant", Content: "First thoughtSecond thoughtFinal answer"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Complex question<|Assistant|>Second thoughtFinal answer<|end▁of▁sentence|>`, + }, + { + name: "thinking enabled after tool call - should render thinking", + messages: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Temperature: 22°C, Sunny"}, + {Role: "assistant", Content: "Based on the weather data, it's sunny in Paris."}, + {Role: "user", Content: "Now tell me about London weather too."}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>What's the weather in Paris?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|>Based on the weather data, it's sunny in Paris.<|end▁of▁sentence|><|User|>Now tell me about London weather too.<|Assistant|>`, + }, + { + name: "thinking disabled after tool call - should not render thinking", + messages: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Temperature: 22°C, Sunny"}, + {Role: "assistant", Content: "Based on the weather data, it's sunny in Paris."}, + {Role: "user", Content: "Now tell me about London weather too."}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>What's the weather in Paris?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|>Based on the weather data, it's sunny in Paris.<|end▁of▁sentence|><|User|>Now tell me about London weather too.<|Assistant|>`, + }, + { + name: "thinking enabled but messages without thinking content", + messages: []api.Message{ + {Role: "user", Content: "First question about cats"}, + {Role: "assistant", Content: "Cats are wonderful pets."}, + {Role: "user", Content: "What about dogs?"}, + {Role: "assistant", Content: "Dogs are loyal companions."}, + {Role: "user", Content: "Final question about birds"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>First question about cats<|Assistant|>Cats are wonderful pets.<|end▁of▁sentence|><|User|>What about dogs?<|Assistant|>Dogs are loyal companions.<|end▁of▁sentence|><|User|>Final question about birds<|Assistant|>`, + }, + { + name: "thinking disabled for all assistant responses", + 
messages: []api.Message{ + {Role: "user", Content: "First question about cats"}, + {Role: "assistant", Content: "Cats are wonderful pets."}, + {Role: "user", Content: "What about dogs?"}, + {Role: "assistant", Content: "Dogs are loyal companions."}, + {Role: "user", Content: "Final question about birds"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>First question about cats<|Assistant|>Cats are wonderful pets.<|end▁of▁sentence|><|User|>What about dogs?<|Assistant|>Dogs are loyal companions.<|end▁of▁sentence|><|User|>Final question about birds<|Assistant|>`, + }, + { + name: "complex conversation with tool calls and thinking enabled", + messages: []api.Message{ + {Role: "user", Content: "Tell me about the weather"}, + {Role: "assistant", Content: "I'll check the weather for you."}, + {Role: "user", Content: "Actually, get Paris weather specifically"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Paris: 22°C, Sunny"}, + {Role: "assistant", Content: "The weather in Paris is great!"}, + {Role: "user", Content: "What about the forecast for tomorrow?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>Tell me about the weather<|Assistant|>I'll check the weather for you.<|end▁of▁sentence|><|User|>Actually, get Paris weather specifically<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Paris: 22°C, Sunny<|tool▁output▁end|>The weather in Paris is great!<|end▁of▁sentence|><|User|>What about the forecast for tomorrow?<|Assistant|>`, + }, + { + name: "tool call without subsequent user message - no thinking", + messages: []api.Message{ + {Role: "user", Content: "Get the weather"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + {Role: "tool", Content: "22°C, Sunny"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>Get the weather<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>22°C, Sunny<|tool▁output▁end|>`, + }, + { + name: "messages with thinking content, no thinking in render", + messages: []api.Message{ + {Role: "user", Content: "Solve this math problem: 15 * 23"}, + { + Role: "assistant", + Content: "The answer is 345.", + Thinking: "Let me calculate 15 * 23. I can break this down: 15 * 20 = 300, and 15 * 3 = 45, so 300 + 45 = 345.", + }, + {Role: "user", Content: "What about 12 * 34?"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|><|User|>Solve this math problem: 15 * 23<|Assistant|>The answer is 345.<|end▁of▁sentence|><|User|>What about 12 * 34?<|Assistant|>`, + }, + { + name: "conversation with mix of thinking and no thinking", + messages: []api.Message{ + {Role: "user", Content: "Explain quantum physics"}, + { + Role: "assistant", + Content: "Quantum physics is the study of matter and energy at the smallest scales.", + Thinking: "This is a complex topic. 
I should start with basic concepts and avoid overwhelming technical details.", + }, + {Role: "user", Content: "What about photons?"}, + { + Role: "assistant", + Content: "Photons are particles of light with no mass.", + }, + {Role: "user", Content: "How do they interact with matter?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>Explain quantum physics<|Assistant|>Quantum physics is the study of matter and energy at the smallest scales.<|end▁of▁sentence|><|User|>What about photons?<|Assistant|>Photons are particles of light with no mass.<|end▁of▁sentence|><|User|>How do they interact with matter?<|Assistant|>`, + }, + { + name: "tool call with thinking content in response", + messages: []api.Message{ + {Role: "user", Content: "What's the weather in Tokyo and New York?"}, + { + Role: "assistant", + Content: "I'll check the weather for both cities.", + Thinking: "I need to call the weather API for two different cities. Let me make parallel calls.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "New York", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Tokyo: 18°C, Cloudy"}, + {Role: "tool", Content: "New York: 22°C, Sunny"}, + { + Role: "assistant", + Content: "Based on the weather data: Tokyo is cloudy at 18°C, while New York is sunny at 22°C.", + Thinking: "The data shows a nice contrast between the two cities. Tokyo is cooler and overcast while NYC has better weather.", + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>What's the weather in Tokyo and New York?<|Assistant|>I'll check the weather for both cities.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"New York"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Tokyo: 18°C, Cloudy<|tool▁output▁end|><|tool▁output▁begin|>New York: 22°C, Sunny<|tool▁output▁end|>Based on the weather data: Tokyo is cloudy at 18°C, while New York is sunny at 22°C.<|end▁of▁sentence|>`, + }, + { + name: "empty thinking field", + messages: []api.Message{ + {Role: "user", Content: "Simple question"}, + { + Role: "assistant", + Content: "Simple answer.", + Thinking: "", // Empty thinking content + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>Simple question<|Assistant|>Simple answer.<|end▁of▁sentence|>`, + }, + { + name: "with tools definitions", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What's the weather like?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant. 
+ +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>What's the weather like?<|Assistant|>`, + }, + { + name: "tools definitions with thinking enabled", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What's the weather in Paris?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant. + +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>What's the weather in Paris?<|Assistant|>`, + }, + { + name: "tools definitions with actual tool call", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant. 
+ +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>What's the weather in Paris?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>`, + }, + { + name: "tools definitions with full conversation cycle", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + Content: "I'll check the weather for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Temperature: 22°C, Sunny"}, + {Role: "assistant", Content: "The weather in Paris is 22°C and sunny!"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant. + +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>What's the weather in Paris?<|Assistant|>I'll check the weather for you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|>The weather in Paris is 22°C and sunny!<|end▁of▁sentence|>`, + }, + { + name: "tools with thinking and full conversation", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Check the weather in Tokyo"}, + { + Role: "assistant", + Thinking: "The user wants weather info for Tokyo. 
I should use the get_weather tool.", + Content: "Let me check that for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Temperature: 18°C, Cloudy"}, + { + Role: "assistant", + Thinking: "The weather data shows it's cloudy and cool. I should present this clearly.", + Content: "In Tokyo, it's currently 18°C and cloudy.", + }, + {Role: "user", Content: "What about London?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant. + +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>Check the weather in Tokyo<|Assistant|>Let me check that for you.<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Temperature: 18°C, Cloudy<|tool▁output▁end|>In Tokyo, it's currently 18°C and cloudy.<|end▁of▁sentence|><|User|>What about London?<|Assistant|>`, + }, + { + name: "multiple tools definitions", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to multiple tools."}, + {Role: "user", Content: "What can you help me with?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "calculate", + Description: "Perform mathematical calculations", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "expression": { + Type: api.PropertyType{"string"}, + Description: "Mathematical expression to evaluate", + }, + }, + Required: []string{"expression"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant with access to multiple tools. 
+ +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +### calculate +Description: Perform mathematical calculations + +Parameters: {"type":"object","required":["expression"],"properties":{"expression":{"type":"string","description":"Mathematical expression to evaluate"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>What can you help me with?<|Assistant|>`, + }, + { + name: "multiple tools with multiple tool calls", + messages: []api.Message{ + {Role: "user", Content: "Get weather for Paris and calculate 25 * 4"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Paris", + }, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "calculate", + Arguments: api.ToolCallFunctionArguments{ + "expression": "25 * 4", + }, + }, + }, + }, + }, + {Role: "tool", Content: "Temperature: 22°C, Sunny"}, + {Role: "tool", Content: "Result: 100"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "calculate", + Description: "Perform mathematical calculations", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "expression": { + Type: api.PropertyType{"string"}, + Description: "Mathematical expression to evaluate", + }, + }, + Required: []string{"expression"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|> + +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +### calculate +Description: Perform mathematical calculations + +Parameters: {"type":"object","required":["expression"],"properties":{"expression":{"type":"string","description":"Mathematical expression to evaluate"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>Get weather for Paris and calculate 25 * 
4<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Paris"}<|tool▁call▁end|><|tool▁call▁begin|>calculate<|tool▁sep|>{"expression":"25 * 4"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁output▁begin|>Temperature: 22°C, Sunny<|tool▁output▁end|><|tool▁output▁begin|>Result: 100<|tool▁output▁end|>`, + }, + { + name: "tools without system message", + messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather information", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "City name", + }, + }, + Required: []string{"location"}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `<|begin▁of▁sentence|> + +## Tools +You have access to the following tools: + +### get_weather +Description: Get current weather information + +Parameters: {"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}} + +IMPORTANT: ALWAYS adhere to this exact format for tool use: +<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|> + +Where: + +- ` + "`tool_call_name`" + ` must be an exact match to one of the available tools +- ` + "`tool_call_arguments`" + ` must be valid JSON that strictly follows the tool's Parameters Schema +- For multiple tool calls, chain them directly without separators or spaces +<|User|>What's the weather?<|Assistant|>`, + }, + { + name: "multi-turn conversation with thinking content on each turn", + messages: []api.Message{ + {Role: "user", Content: "hey!"}, + { + Role: "assistant", + Content: "Hey! 😊 How's it going? What's on your mind today?", + Thinking: "Hmm, the user just said \"hey!\" which is a simple greeting. This is a straightforward opening where they're likely just starting a conversation or testing the interaction.", + }, + {Role: "user", Content: "fantastic, how has yours been"}, + { + Role: "assistant", + Content: "Glad to hear you're having a fantastic day! That's awesome.\n\nMine's been great, thanks for asking! Just buzzing along, helping people out and having conversations like this one. So what's making your day so fantastic? Anything fun happening?", + Thinking: "Ah, the user is responding warmly and asking about my \"day.\" Since I'm an AI, I need to gently remind them I don't experience time like a human, but frame it positively to keep the conversation flowing.", + }, + {Role: "user", Content: "awesome, can you tell me a 10 word story?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>hey!<|Assistant|>Hey! 😊 How's it going? What's on your mind today?<|end▁of▁sentence|><|User|>fantastic, how has yours been<|Assistant|>Glad to hear you're having a fantastic day! That's awesome. + +Mine's been great, thanks for asking! Just buzzing along, helping people out and having conversations like this one. So what's making your day so fantastic? 
Anything fun happening?<|end▁of▁sentence|><|User|>awesome, can you tell me a 10 word story?<|Assistant|>`, + }, + { + name: "vLLM documentation example - multi-turn with full thinking content", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant"}, + {Role: "user", Content: "Who are you?"}, + { + Role: "assistant", + Content: "I am DeepSeek", + Thinking: "Hmm", + }, + {Role: "user", Content: "9.11 and 9.8, which is greater?"}, + { + Role: "assistant", + Content: "9.8 is greater than 9.11.\n\nTo compare them easily, you can align the decimal places: \n- 9.11 has a tenths digit of 1 and a hundredths digit of 1. \n- 9.8 can be written as 9.80, which has a tenths digit of 8 and a hundredths digit of 0. \n\nSince the whole number part (9) is the same, compare the tenths place: 8 is greater than 1, so 9.80 (or 9.8) is greater than 9.11.", + Thinking: "First, the user is asking which number is greater between 9.11 and 9.8. These are decimal numbers.\n\nI need to compare 9.11 and 9.8. To make it easier, I should think of them as decimals with the same number of decimal places.\n\n9.11 has two decimal places, and 9.8 has one decimal place. I can write 9.8 as 9.80 to make it comparable.\n\nSo, 9.11 versus 9.80.\n\nNow, comparing the whole number parts: both have 9, so they are equal in the units place.\n\nNext, compare the tenths place: for 9.11, the tenths digit is 1. For 9.80, the tenths digit is 8. Since 8 is greater than 1, 9.80 is greater than 9.11.\n\nTherefore, 9.8 is greater than 9.11.\n\nI can also think of them as fractions: 9.11 is 911/100, and 9.8 is 98/10 or 980/100. Comparing 911/100 and 980/100, 980/100 is larger, so 9.8 is greater.\n\nSo, the answer should be that 9.8 is greater than 9.11.\n\nNow, I need to respond helpfully. Since the user might be learning or need clarification, I should explain briefly.\n\nFinally, my response should be in English, as the query is in English.", + }, + {Role: "user", Content: "Thanks! Can you tell me a 10 word story?"}, + { + Role: "assistant", + Content: "Of course. Here is a 10-word story:\n\nHe found the key, unlocking a door to forgotten memories.", + Thinking: "Hmm, the user just asked for a 10-word story after the previous number comparison question. This is a quick creative task with a strict word count constraint. \n\nThe story needs to be exactly 10 words while maintaining coherence and a hint of narrative. A micro-story about finding a key could work - it implies a larger unseen story. \n\nChecking the word count: \"He found the key, unlocking a door to forgotten memories.\" That's 10 words with a beginning, middle and implied end. It fits the requirement while leaving room for imagination.", + }, + {Role: "user", Content: "That was beautiful! Now can you write a haiku?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|>You are a helpful assistant<|User|>Who are you?<|Assistant|>I am DeepSeek<|end▁of▁sentence|><|User|>9.11 and 9.8, which is greater?<|Assistant|>9.8 is greater than 9.11. + +To compare them easily, you can align the decimal places: +- 9.11 has a tenths digit of 1 and a hundredths digit of 1. +- 9.8 can be written as 9.80, which has a tenths digit of 8 and a hundredths digit of 0. + +Since the whole number part (9) is the same, compare the tenths place: 8 is greater than 1, so 9.80 (or 9.8) is greater than 9.11.<|end▁of▁sentence|><|User|>Thanks! Can you tell me a 10 word story?<|Assistant|>Of course. 
Here is a 10-word story: + +He found the key, unlocking a door to forgotten memories.<|end▁of▁sentence|><|User|>That was beautiful! Now can you write a haiku?<|Assistant|>`, + }, + { + name: "no system prompt - content with embedded thinking tags", + messages: []api.Message{ + {Role: "user", Content: "Who are you?"}, + {Role: "assistant", Content: "HmmI am DeepSeek"}, + {Role: "user", Content: "Thanks! Can you tell me a 10 word story?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `<|begin▁of▁sentence|><|User|>Who are you?<|Assistant|>I am DeepSeek<|end▁of▁sentence|><|User|>Thanks! Can you tell me a 10 word story?<|Assistant|>`, + }, + } + + renderer := &DeepSeek3Renderer{IsThinking: true, Variant: Deepseek31} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue) + if err != nil { + t.Fatalf("Render() error = %v", err) + } + if diff := cmp.Diff(tt.expected, rendered); diff != "" { + t.Errorf("Render() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/functiongemma.go b/model/renderers/functiongemma.go new file mode 100644 index 00000000..dcbcc062 --- /dev/null +++ b/model/renderers/functiongemma.go @@ -0,0 +1,287 @@ +package renderers + +import ( + "fmt" + "sort" + "strings" + + "github.com/ollama/ollama/api" +) + +type FunctionGemmaRenderer struct{} + +const defaultSystemMessage = "You can do function calling with the following functions:" + +func (r *FunctionGemmaRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + sb.WriteString("") + + var systemMessage string + var loopMessages []api.Message + if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") { + systemMessage = messages[0].Content + loopMessages = messages[1:] + } else { + loopMessages = messages + } + + if systemMessage != "" || len(tools) > 0 { + sb.WriteString("developer\n") + if systemMessage != "" { + sb.WriteString(strings.TrimSpace(systemMessage)) + } + if len(tools) > 0 { + if systemMessage != "" { + sb.WriteString("\n") + } + if strings.TrimSpace(systemMessage) != defaultSystemMessage { + // Only add default message if user does not provide it + sb.WriteString(defaultSystemMessage) + } + } + for _, tool := range tools { + sb.WriteString(r.renderToolDeclaration(tool)) + } + sb.WriteString("\n") + } + + // Track previous message type for tool response handling + prevMessageType := "" + + for i, message := range loopMessages { + switch message.Role { + case "assistant": + if prevMessageType != "tool_response" { + sb.WriteString("model\n") + } + prevMessageType = "" + + if message.Content != "" { + sb.WriteString(strings.TrimSpace(message.Content)) + } + + if len(message.ToolCalls) > 0 { + for _, tc := range message.ToolCalls { + sb.WriteString(r.formatToolCall(tc)) + } + // After tool calls, expect tool responses + if i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" { + sb.WriteString("") + prevMessageType = "tool_call" + } else { + sb.WriteString("\n") + } + } else { + sb.WriteString("\n") + } + + case "user": + if prevMessageType != "tool_response" { + sb.WriteString("user\n") + } + prevMessageType = "" + sb.WriteString(strings.TrimSpace(message.Content)) + sb.WriteString("\n") + + case "tool": + toolName := "" + // Find the tool name from the previous assistant's tool call + for j := i - 1; j >= 0; j-- { + if loopMessages[j].Role == "assistant" && 
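+ // only an assistant turn that actually issued tool calls can supply the name for this tool response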
len(loopMessages[j].ToolCalls) > 0 { + // Count how many tool messages came before this one + toolIdx := 0 + for k := j + 1; k < i; k++ { + if loopMessages[k].Role == "tool" { + toolIdx++ + } + } + if toolIdx < len(loopMessages[j].ToolCalls) { + toolName = loopMessages[j].ToolCalls[toolIdx].Function.Name + } + break + } + } + + if prevMessageType != "tool_call" { + sb.WriteString("") + } + sb.WriteString("response:" + toolName + "{" + r.formatArgValue(message.Content) + "}") + prevMessageType = "tool_response" + + default: + sb.WriteString("" + message.Role + "\n") + sb.WriteString(strings.TrimSpace(message.Content)) + sb.WriteString("\n") + } + } + + if prevMessageType != "tool_response" { + sb.WriteString("model\n") + } + + return sb.String(), nil +} + +func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string { + var sb strings.Builder + + fn := tool.Function + sb.WriteString("declaration:" + fn.Name + "{") + sb.WriteString("description:" + fn.Description + "") + + if fn.Parameters.Properties != nil || fn.Parameters.Type != "" { + sb.WriteString(",parameters:{") + + needsComma := false + + // Only include properties:{} if there are actual properties + if len(fn.Parameters.Properties) > 0 { + sb.WriteString("properties:{") + r.writeProperties(&sb, fn.Parameters.Properties) + sb.WriteString("}") + needsComma = true + } + + if len(fn.Parameters.Required) > 0 { + if needsComma { + sb.WriteString(",") + } + sb.WriteString("required:[") + for i, req := range fn.Parameters.Required { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString("" + req + "") + } + sb.WriteString("]") + needsComma = true + } + + if fn.Parameters.Type != "" { + if needsComma { + sb.WriteString(",") + } + sb.WriteString("type:" + strings.ToUpper(fn.Parameters.Type) + "") + } + + sb.WriteString("}") + } + + sb.WriteString("}") + return sb.String() +} + +func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) { + keys := make([]string, 0, len(props)) + for k := range props { + keys = append(keys, k) + } + sort.Strings(keys) + + first := true + for _, name := range keys { + prop := props[name] + if !first { + sb.WriteString(",") + } + first = false + + sb.WriteString(name + ":{description:") + sb.WriteString(prop.Description) + sb.WriteString("") + + if len(prop.Type) > 0 { + sb.WriteString(",type:" + strings.ToUpper(prop.Type[0]) + "") + } + + sb.WriteString("}") + } +} + +func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string { + var sb strings.Builder + sb.WriteString("call:" + tc.Function.Name + "{") + + keys := make([]string, 0, len(tc.Function.Arguments)) + for k := range tc.Function.Arguments { + keys = append(keys, k) + } + sort.Strings(keys) + + first := true + for _, key := range keys { + value := tc.Function.Arguments[key] + if !first { + sb.WriteString(",") + } + first = false + sb.WriteString(key + ":" + r.formatArgValue(value)) + } + + sb.WriteString("}") + return sb.String() +} + +func (r *FunctionGemmaRenderer) formatArgValue(value any) string { + switch v := value.(type) { + case string: + return "" + v + "" + case bool: + if v { + return "true" + } + return "false" + case float64: + if v == float64(int64(v)) { + return fmt.Sprintf("%d", int64(v)) + } + return fmt.Sprintf("%v", v) + case int, int64, int32: + return fmt.Sprintf("%d", v) + case map[string]any: + return r.formatMapValue(v) + case []any: + return r.formatArrayValue(v) + default: + return fmt.Sprintf("%v", v) + } +} + +func (r *FunctionGemmaRenderer) 
formatMapValue(m map[string]any) string { + var sb strings.Builder + sb.WriteString("{") + + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + first := true + for _, key := range keys { + if !first { + sb.WriteString(",") + } + first = false + sb.WriteString(key + ":" + r.formatArgValue(m[key])) + } + + sb.WriteString("}") + return sb.String() +} + +func (r *FunctionGemmaRenderer) formatArrayValue(arr []any) string { + var sb strings.Builder + sb.WriteString("[") + + for i, item := range arr { + if i > 0 { + sb.WriteString(",") + } + sb.WriteString(r.formatArgValue(item)) + } + + sb.WriteString("]") + return sb.String() +} diff --git a/model/renderers/functiongemma_test.go b/model/renderers/functiongemma_test.go new file mode 100644 index 00000000..733ff374 --- /dev/null +++ b/model/renderers/functiongemma_test.go @@ -0,0 +1,514 @@ +package renderers + +import ( + "testing" + + "github.com/ollama/ollama/api" + "github.com/stretchr/testify/assert" +) + +func TestFunctionGemmaRenderer(t *testing.T) { + tests := []struct { + name string + messages []api.Message + tools []api.Tool + expected string + }{ + { + name: "basic_user_message", + messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + expected: "user\nHello!\nmodel\n", + }, + { + name: "with_system_message", + messages: []api.Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hello!"}, + }, + expected: "developer\nYou are helpful\nuser\nHello!\nmodel\n", + }, + { + name: "with_developer_role", + messages: []api.Message{ + {Role: "developer", Content: "You are a coding assistant"}, + {Role: "user", Content: "Hello!"}, + }, + expected: "developer\nYou are a coding assistant\nuser\nHello!\nmodel\n", + }, + { + name: "custom_system_message_with_tools", + messages: []api.Message{ + {Role: "system", Content: "You are a weather expert."}, + {Role: "user", Content: "Weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + }, + // Custom system message is preserved, tools are appended + expected: "developer\nYou are a weather expert.\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}\nuser\nWeather?\nmodel\n", + }, + { + name: "developer_role_with_tools", + messages: []api.Message{ + {Role: "developer", Content: "Be concise."}, + {Role: "user", Content: "Weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + }, + // Developer role message is preserved, tools are appended + expected: "developer\nBe concise.\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}\nuser\nWeather?\nmodel\n", + }, + { + name: "multi_turn", + messages: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello!"}, + {Role: "user", 
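+ // final user turn: the renderer should append a fresh model generation prompt after it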
Content: "More"}, + }, + expected: "user\nHi\nmodel\nHello!\nuser\nMore\nmodel\n", + }, + { + name: "with_tools", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}\nuser\nWeather?\nmodel\n", + }, + { + name: "tool_call", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}\nuser\nWeather?\nmodel\ncall:get_weather{city:Paris}response:get_weather{Sunny}", + }, + { + name: "assistant_content_with_tool_call", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + Content: "Let me check.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}\nuser\nWeather?\nmodel\nLet me check.call:get_weather{city:Paris}response:get_weather{Sunny}", + }, + { + name: "numeric_arguments", + messages: []api.Message{ + {Role: "user", Content: "Add"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "add", + Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)}, + }, + }, + }, + }, + {Role: "tool", Content: "3"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "add", + Description: "Add numbers", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "a": {Type: api.PropertyType{"number"}}, + "b": {Type: api.PropertyType{"number"}}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:add{description:Add 
numbers,parameters:{properties:{a:{description:,type:NUMBER},b:{description:,type:NUMBER}},type:OBJECT}}\nuser\nAdd\nmodel\ncall:add{a:1,b:2}response:add{3}", + }, + { + name: "empty_messages", + messages: []api.Message{}, + expected: "model\n", + }, + { + name: "tool_with_required_params", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Gets the weather for a given city", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"city"}, + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City Name"}, + "country": {Type: api.PropertyType{"string"}, Description: "Country Name"}, + }, + }, + }, + }, + }, + // Required params are escaped: required:[city] + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Gets the weather for a given city,parameters:{properties:{city:{description:City Name,type:STRING},country:{description:Country Name,type:STRING}},required:[city],type:OBJECT}}\nuser\nWeather?\nmodel\n", + }, + { + name: "multiple_tools", + messages: []api.Message{ + {Role: "user", Content: "Weather and time?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_time", + Description: "Get current time", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"}, + }, + }, + }, + }, + }, + // Multiple tool declarations are consecutive + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}declaration:get_time{description:Get current time,parameters:{properties:{timezone:{description:Timezone,type:STRING}},type:OBJECT}}\nuser\nWeather and time?\nmodel\n", + }, + { + name: "parallel_tool_calls", + messages: []api.Message{ + {Role: "user", Content: "Weather and time?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_time", + Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + {Role: "tool", Content: "12:00"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_time", + Description: "Get current time", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"}, + }, + }, + }, + }, + }, + // Multiple tool calls and responses are 
consecutive + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}declaration:get_time{description:Get current time,parameters:{properties:{timezone:{description:Timezone,type:STRING}},type:OBJECT}}\nuser\nWeather and time?\nmodel\ncall:get_weather{city:Paris}call:get_time{timezone:UTC}response:get_weather{Sunny}response:get_time{12:00}", + }, + { + name: "user_after_tool_response", + messages: []api.Message{ + {Role: "user", Content: "Weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + {Role: "user", Content: "Thanks! What about London?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "City"}, + }, + }, + }, + }, + }, + // User message after tool response gets concatenated (user reverted to this behavior) + expected: "developer\nYou can do function calling with the following functions:declaration:get_weather{description:Get weather,parameters:{properties:{city:{description:City,type:STRING}},type:OBJECT}}\nuser\nWeather?\nmodel\ncall:get_weather{city:Paris}response:get_weather{Sunny}Thanks! What about London?\nmodel\n", + }, + // Edge cases + { + name: "tool_empty_properties", + messages: []api.Message{ + {Role: "user", Content: "Test"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "test_fn", + Description: "", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{}, + }, + }, + }, + }, + // Empty properties are omitted + expected: "developer\nYou can do function calling with the following functions:declaration:test_fn{description:,parameters:{type:OBJECT}}\nuser\nTest\nmodel\n", + }, + { + name: "unicode_content", + messages: []api.Message{ + {Role: "user", Content: "こんにちは 🎉"}, + }, + expected: "user\nこんにちは 🎉\nmodel\n", + }, + { + name: "newlines_in_content", + messages: []api.Message{ + {Role: "user", Content: "Line 1\nLine 2\nLine 3"}, + }, + expected: "user\nLine 1\nLine 2\nLine 3\nmodel\n", + }, + { + name: "special_chars_in_content", + messages: []api.Message{ + {Role: "user", Content: "Test & \"quotes\" chars"}, + }, + expected: "user\nTest & \"quotes\" chars\nmodel\n", + }, + { + name: "boolean_argument", + messages: []api.Message{ + {Role: "user", Content: "Set flag"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "set_flag", + Arguments: api.ToolCallFunctionArguments{"enabled": true}, + }, + }, + }, + }, + {Role: "tool", Content: "done"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "set_flag", + Description: "Set a flag", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:set_flag{description:Set a flag,parameters:{properties:{enabled:{description:Flag 
value,type:BOOLEAN}},type:OBJECT}}\nuser\nSet flag\nmodel\ncall:set_flag{enabled:true}response:set_flag{done}", + }, + { + name: "multiple_required_params", + messages: []api.Message{ + {Role: "user", Content: "Test"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "test", + Description: "Test", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"a", "b", "c"}, + Properties: map[string]api.ToolProperty{ + "a": {Type: api.PropertyType{"string"}, Description: "A"}, + "b": {Type: api.PropertyType{"string"}, Description: "B"}, + "c": {Type: api.PropertyType{"string"}, Description: "C"}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:test{description:Test,parameters:{properties:{a:{description:A,type:STRING},b:{description:B,type:STRING},c:{description:C,type:STRING}},required:[a,b,c],type:OBJECT}}\nuser\nTest\nmodel\n", + }, + { + name: "array_type_param", + messages: []api.Message{ + {Role: "user", Content: "Test"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "test", + Description: "Test", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "items": {Type: api.PropertyType{"array"}, Description: "List of items"}, + }, + }, + }, + }, + }, + expected: "developer\nYou can do function calling with the following functions:declaration:test{description:Test,parameters:{properties:{items:{description:List of items,type:ARRAY}},type:OBJECT}}\nuser\nTest\nmodel\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + renderer := &FunctionGemmaRenderer{} + result, err := renderer.Render(tt.messages, tt.tools, nil) + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/model/renderers/nemotron3nano.go b/model/renderers/nemotron3nano.go new file mode 100644 index 00000000..478a59bd --- /dev/null +++ b/model/renderers/nemotron3nano.go @@ -0,0 +1,220 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +type Nemotron3NanoRenderer struct{} + +func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + // thinking is enabled if user requests it + enableThinking := thinkValue != nil && thinkValue.Bool() + + // Extract system message if present + var systemMessage string + var loopMessages []api.Message + if len(messages) > 0 && messages[0].Role == "system" { + systemMessage = messages[0].Content + loopMessages = messages[1:] + } else { + loopMessages = messages + } + + // Find last user message index for thinking truncation + lastUserIdx := -1 + for i, msg := range loopMessages { + if msg.Role == "user" { + lastUserIdx = i + } + } + + sb.WriteString("<|im_start|>system\n") + if systemMessage != "" { + sb.WriteString(systemMessage) + } + + if len(tools) > 0 { + if systemMessage != "" { + sb.WriteString("\n\n") + } + sb.WriteString(r.renderTools(tools)) + } + sb.WriteString("<|im_end|>\n") + + for i, message := range loopMessages { + switch message.Role { + case "assistant": + // Build content with thinking tags + content := r.buildContent(message) + shouldTruncate := i < lastUserIdx + + if len(message.ToolCalls) > 0 { + sb.WriteString("<|im_start|>assistant\n") + sb.WriteString(r.formatContent(content, shouldTruncate, true)) + r.writeToolCalls(&sb, 
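+ // append one formatted argument block per tool call after the assistant content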
message.ToolCalls) + sb.WriteString("<|im_end|>\n") + } else { + formatted := r.formatContent(content, shouldTruncate, false) + sb.WriteString("<|im_start|>assistant\n" + formatted + "<|im_end|>\n") + } + + case "user", "system": + sb.WriteString("<|im_start|>" + message.Role + "\n") + sb.WriteString(message.Content) + sb.WriteString("<|im_end|>\n") + + case "tool": + // Check if previous message was also a tool message + prevWasTool := i > 0 && loopMessages[i-1].Role == "tool" + nextIsTool := i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" + + if !prevWasTool { + sb.WriteString("<|im_start|>user\n") + } + sb.WriteString("\n") + sb.WriteString(message.Content) + sb.WriteString("\n\n") + + if !nextIsTool { + sb.WriteString("<|im_end|>\n") + } + + default: + sb.WriteString("<|im_start|>" + message.Role + "\n" + message.Content + "<|im_end|>\n") + } + } + + // Add generation prompt + if enableThinking { + sb.WriteString("<|im_start|>assistant\n\n") + } else { + sb.WriteString("<|im_start|>assistant\n") + } + + return sb.String(), nil +} + +func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string { + var sb strings.Builder + sb.WriteString("# Tools\n\nYou have access to the following functions:\n\n") + + for _, tool := range tools { + fn := tool.Function + sb.WriteString("\n\n" + fn.Name + "") + + if fn.Description != "" { + sb.WriteString("\n" + strings.TrimSpace(fn.Description) + "") + } + + sb.WriteString("\n") + if fn.Parameters.Properties != nil { + for paramName, paramFields := range fn.Parameters.Properties { + sb.WriteString("\n") + sb.WriteString("\n" + paramName + "") + + if len(paramFields.Type) > 0 { + sb.WriteString("\n" + strings.Join(paramFields.Type, ", ") + "") + } + + if paramFields.Description != "" { + sb.WriteString("\n" + strings.TrimSpace(paramFields.Description) + "") + } + + if len(paramFields.Enum) > 0 { + enumJSON, _ := json.Marshal(paramFields.Enum) + sb.WriteString("\n" + string(enumJSON) + "") + } + + sb.WriteString("\n") + } + } + + if len(fn.Parameters.Required) > 0 { + reqJSON, _ := json.Marshal(fn.Parameters.Required) + sb.WriteString("\n" + string(reqJSON) + "") + } + + sb.WriteString("\n") + sb.WriteString("\n") + } + + sb.WriteString("\n") + + sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n") + + return sb.String() +} + +func (r *Nemotron3NanoRenderer) buildContent(message api.Message) string { + // The parser always extracts thinking into the Thinking field, + // so Content will never have tags embedded + if message.Thinking != "" { + return "\n" + message.Thinking + "\n\n" + message.Content + } + return "" + message.Content +} + +func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, addNewline bool) string { + if content == "" { + return "" + } + + if !truncate { + if addNewline { + return strings.TrimSpace(content) + "\n" + } + return strings.TrimSpace(content) + } + + // Truncate 
thinking - keep only content after + c := content + if strings.Contains(c, "") { + parts := strings.Split(c, "") + c = parts[len(parts)-1] + } else if strings.Contains(c, "") { + parts := strings.Split(c, "") + c = parts[0] + } + c = "" + strings.TrimSpace(c) + + if addNewline && len(c) > len("") { + return c + "\n" + } + if c == "" { + return c + } + return strings.TrimSpace(c) +} + +func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) { + for _, tc := range toolCalls { + sb.WriteString("\n\n") + for name, value := range tc.Function.Arguments { + sb.WriteString("\n" + r.formatArgValue(value) + "\n\n") + } + sb.WriteString("\n\n") + } +} + +func (r *Nemotron3NanoRenderer) formatArgValue(value any) string { + switch v := value.(type) { + case map[string]any, []any: + jsonBytes, _ := json.Marshal(v) + return string(jsonBytes) + default: + return fmt.Sprintf("%v", v) + } +} diff --git a/model/renderers/nemotron3nano_test.go b/model/renderers/nemotron3nano_test.go new file mode 100644 index 00000000..ca1feb93 --- /dev/null +++ b/model/renderers/nemotron3nano_test.go @@ -0,0 +1,569 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +func TestNemotron3NanoRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + expected string + }{ + { + name: "basic user message - thinking mode", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "basic user message - no thinking", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + thinkValue: nil, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n", + }, + { + name: "with system message", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "multi-turn conversation", + msgs: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello! How can I help?"}, + {Role: "user", Content: "Tell me a joke"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHi<|im_end|>\n" + + "<|im_start|>assistant\nHello! 
How can I help?<|im_end|>\n" + + "<|im_start|>user\nTell me a joke<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "with tools", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"city"}, + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "Get the current weather\n" + + "\n" + + "\ncity\nstring\nThe city name\n\n" + + "[\"city\"]\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "tool call with response", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny, 72F"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"city"}, + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}, Description: "The city name"}, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "Get the current weather\n" + + "\n" + + "\ncity\nstring\nThe city name\n\n" + + "[\"city\"]\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather in Paris?<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\nParis\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny, 
72F\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "assistant with content and tool call", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "Let me check that for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "\n" + + "\ncity\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather?<|im_end|>\n" + + "<|im_start|>assistant\nLet me check that for you.\n" + + "\n\n\nParis\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "thinking in history is truncated", + msgs: []api.Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello!", Thinking: "Let me think about this..."}, + {Role: "user", Content: "How are you?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHi<|im_end|>\n" + + "<|im_start|>assistant\nHello!<|im_end|>\n" + + "<|im_start|>user\nHow are you?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "parallel tool calls", + msgs: []api.Message{ + {Role: "user", Content: "Weather in Paris and London?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "Paris"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"city": "London"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny"}, + {Role: "tool", Content: "Rainy"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "\n" + + "\ncity\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + 
"\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWeather in Paris and London?<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\nParis\n\n\n\n" + + "\n\n\nLondon\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny\n\n\nRainy\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "thinking disabled when user doesn't request it", + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + thinkValue: nil, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHello!<|im_end|>\n" + + "<|im_start|>assistant\n", + }, + { + name: "complex message history with thinking, tools, tool calls, tool results and content", + msgs: []api.Message{ + {Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"}, + {Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}}, + {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}}, + }}, + {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"}, + {Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"}, + {Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}, + }}, + {Role: "tool", Content: "4", ToolCallID: "call3"}, + {Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! 
I have all the information needed to provide a complete answer."}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "city": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "calculate", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "expression": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_weather\n" + + "\n" + + "\ncity\nstring\n\n" + + "\n\n" + + "\ncalculate\n" + + "\n" + + "\nexpression\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nWhat's the weather in Paris and London? Also, what's 2+2?<|im_end|>\n" + + "<|im_start|>assistant\n" + + "\nI need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.\n\n" + + "\n\n\nParis\n\n\n\n" + + "\n\n\nLondon\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nSunny, 22°C\n\n\nRainy, 15°C\n\n<|im_end|>\n" + + "<|im_start|>assistant\n" + + "\nNow I have the weather data. Let me calculate 2+2.\n\n" + + "\n\n\n2+2\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\n4\n\n<|im_end|>\n" + + "<|im_start|>assistant\n" + + "\nPerfect! I have all the information needed to provide a complete answer.\n\n" + + "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. 
Also, 2+2 equals 4.<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "empty messages list", + msgs: []api.Message{}, + thinkValue: nil, + expected: "<|im_start|>system\n<|im_end|>\n<|im_start|>assistant\n", + }, + { + name: "tool result with JSON content", + msgs: []api.Message{ + {Role: "user", Content: "Get user info"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}}, + }, + }, + {Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_user", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\nget_user\n\n" + + "\nid\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nGet user info<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\n123\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\n{\"name\": \"John\", \"age\": 30, \"active\": true}\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "assistant message with only thinking no content", + msgs: []api.Message{ + {Role: "user", Content: "Think about this"}, + {Role: "assistant", Thinking: "Deep thoughts here...", Content: ""}, + {Role: "user", Content: "What did you think?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nThink about this<|im_end|>\n" + + "<|im_start|>assistant\n<|im_end|>\n" + + "<|im_start|>user\nWhat did you think?<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "tool call with complex nested argument", + msgs: []api.Message{ + {Role: "user", Content: "Create data"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{ + Name: "create", + Arguments: map[string]any{ + "data": map[string]any{"nested": "value", "count": 42}, + }, + }}, + }, + }, + {Role: "tool", Content: "Created"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "create", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\ncreate\n\n" + + "\ndata\nobject\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple 
lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nCreate data<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\n{\"count\":42,\"nested\":\"value\"}\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nCreated\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "content explaining the format itself", + msgs: []api.Message{ + {Role: "user", Content: "How do I format a tool call?"}, + {Role: "assistant", Content: "To call a tool, use tags with inside."}, + {Role: "user", Content: "Thanks!"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n<|im_end|>\n" + + "<|im_start|>user\nHow do I format a tool call?<|im_end|>\n" + + "<|im_start|>assistant\nTo call a tool, use tags with inside.<|im_end|>\n" + + "<|im_start|>user\nThanks!<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + { + name: "unicode in content and tool args", + msgs: []api.Message{ + {Role: "user", Content: "Translate 你好"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + {Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}}, + }, + }, + {Role: "tool", Content: "Hello"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "translate", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.ToolProperty{ + "text": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: "<|im_start|>system\n" + + "# Tools\n\nYou have access to the following functions:\n\n\n" + + "\ntranslate\n\n" + + "\ntext\nstring\n\n" + + "\n\n\n\n" + + "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n" + + "\n\n\nvalue_1\n\n" + + "\nThis is the value for the second parameter\nthat can span\nmultiple lines\n" + + "\n\n\n\n\nReminder:\n" + + "- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n" + + "- Required parameters MUST be specified\n" + + "- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n" + + "- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n" + + "<|im_end|>\n" + + "<|im_start|>user\nTranslate 你好<|im_end|>\n" + + "<|im_start|>assistant\n\n" + + "\n\n\n你好\n\n\n\n<|im_end|>\n" + + "<|im_start|>user\n\nHello\n\n<|im_end|>\n" + + "<|im_start|>assistant\n\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + renderer := &Nemotron3NanoRenderer{} + rendered, err := renderer.Render(tt.msgs, tt.tools, tt.thinkValue) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/model/renderers/olmo3.go b/model/renderers/olmo3.go index 24ade20d..c6cdaa72 100644 --- a/model/renderers/olmo3.go +++ b/model/renderers/olmo3.go @@ -10,12 +10,15 @@ import ( ) const ( - olmo3DefaultSystemMessage = "You are 
a helpful function-calling AI assistant. " - olmo3NoFunctionsMessage = "You do not currently have access to any functions. " - olmo3WithFunctionsMessage = "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions." + olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. " + olmo31DefaultSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai. " + olmo3NoFunctionsMessage = "You do not currently have access to any functions. " + olmo3WithFunctionsMessage = "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions." ) -type Olmo3Renderer struct{} +type Olmo3Renderer struct { + UseExtendedSystemMessage bool +} func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { var sb strings.Builder @@ -51,7 +54,11 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api. } else { // Default system message - single newline after "system" sb.WriteString("<|im_start|>system\n") - sb.WriteString(olmo3DefaultSystemMessage) + if r.UseExtendedSystemMessage { + sb.WriteString(olmo31DefaultSystemMessage) + } else { + sb.WriteString(olmo3DefaultSystemMessage) + } if len(tools) > 0 { functionsJSON, err := marshalWithSpaces(tools) @@ -140,7 +147,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api. } if needsGenerationPrompt { - sb.WriteString("<|im_start|>assistant\n\n") + sb.WriteString("<|im_start|>assistant\n") } return sb.String(), nil diff --git a/model/renderers/olmo3_test.go b/model/renderers/olmo3_test.go index 56c79a23..be9c4eac 100644 --- a/model/renderers/olmo3_test.go +++ b/model/renderers/olmo3_test.go @@ -24,7 +24,7 @@ func TestOlmo3Renderer(t *testing.T) { "You are a helpful function-calling AI assistant. You do not currently have access to any functions. 
<|im_end|>\n" + "<|im_start|>user\n" + "Hello!<|im_end|>\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "with system message no tools", @@ -36,7 +36,7 @@ func TestOlmo3Renderer(t *testing.T) { "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "Hello!<|im_end|>\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "with system message and tools", @@ -64,7 +64,7 @@ func TestOlmo3Renderer(t *testing.T) { `You are a helpful assistant.[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + "<|im_start|>user\n" + "What is the weather?<|im_end|>\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "default system with tools - includes function instruction", @@ -93,7 +93,7 @@ func TestOlmo3Renderer(t *testing.T) { `[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + "<|im_start|>user\n" + "What is the weather?<|im_end|>\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "assistant with tool calls - function call syntax", @@ -141,7 +141,7 @@ func TestOlmo3Renderer(t *testing.T) { `Let me check the weather.get_weather(location="San Francisco")<|im_end|>` + "\n" + "<|im_start|>environment\n" + `{"temperature": 68}<|im_end|>` + "\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "multi-turn conversation", @@ -159,7 +159,7 @@ func TestOlmo3Renderer(t *testing.T) { "Hi there!<|im_end|>\n" + "<|im_start|>user\n" + "How are you?<|im_end|>\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "parallel tool calls - newline separated", @@ -214,7 +214,7 @@ func TestOlmo3Renderer(t *testing.T) { `{"temperature": 68}<|im_end|>` + "\n" + "<|im_start|>environment\n" + `{"temperature": 55}<|im_end|>` + "\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "tool call with multiple arguments", @@ -259,7 +259,7 @@ func TestOlmo3Renderer(t *testing.T) { "Book a flight<|im_end|>\n" + "<|im_start|>assistant\n" + `book_flight(from="SFO", to="NYC")<|im_end|>` + "\n" + - "<|im_start|>assistant\n\n", + "<|im_start|>assistant\n", }, { name: "assistant prefill - no generation prompt", diff --git a/model/renderers/olmo3_think.go b/model/renderers/olmo3_think.go index b327d044..6fe62110 100644 --- a/model/renderers/olmo3_think.go +++ b/model/renderers/olmo3_think.go @@ -1,31 +1,31 @@ package renderers import ( - "encoding/json" "strings" "github.com/ollama/ollama/api" ) +type Olmo3ThinkVariant int + const ( - olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai." - olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions." + // Olmo3Think32B is for allenai/Olmo-3-32B-Think + Olmo3Think32B Olmo3ThinkVariant = iota + // Olmo31Think is for allenai/Olmo-3-7B-Think and allenai/Olmo-3.1-32B-Think (includes model info) + Olmo31Think ) -type Olmo3ThinkRenderer struct{} +const ( + olmo3ThinkFunctionsSuffix = " You do not currently have access to any functions. 
" + olmo3Think32BSystemMessage = "You are a helpful AI assistant." + olmo31ThinkSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai." +) -type olmo3ThinkToolCall struct { - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Function olmo3ThinkToolCallFunc `json:"function"` +type Olmo3ThinkRenderer struct { + Variant Olmo3ThinkVariant } -type olmo3ThinkToolCallFunc struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { +func (r *Olmo3ThinkRenderer) Render(messages []api.Message, _ []api.Tool, _ *api.ThinkValue) (string, error) { var sb strings.Builder var systemMessage *api.Message @@ -37,34 +37,31 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ } continue } + // Skip tool messages - Think models don't support tools + if message.Role == "tool" { + continue + } filteredMessages = append(filteredMessages, message) } - systemContent := olmo3ThinkDefaultSystemMessage - if systemMessage != nil { - systemContent = systemMessage.Content - } - sb.WriteString("<|im_start|>system\n") - sb.WriteString(systemContent) - if len(tools) > 0 { - functionsJSON, err := marshalWithSpaces(tools) - if err != nil { - return "", err - } - sb.WriteString(" ") - sb.WriteString(string(functionsJSON)) - sb.WriteString("") + if systemMessage != nil { + sb.WriteString(systemMessage.Content) + sb.WriteString(olmo3ThinkFunctionsSuffix) } else { - sb.WriteString(olmo3ThinkNoFunctionsMessage) - sb.WriteString(" ") + // Default system message varies by variant + switch r.Variant { + case Olmo3Think32B: + sb.WriteString(olmo3Think32BSystemMessage) + default: // Olmo3Think7B, Olmo31Think use same template - diverges from HF but confirmed difference from team + sb.WriteString(olmo31ThinkSystemMessage) + } } + sb.WriteString("<|im_end|>\n") - for i, message := range filteredMessages { - lastMessage := i == len(filteredMessages)-1 - + for _, message := range filteredMessages { switch message.Role { case "user": sb.WriteString("<|im_start|>user\n") @@ -73,58 +70,15 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ case "assistant": sb.WriteString("<|im_start|>assistant\n") - if message.Content != "" { sb.WriteString(message.Content) } - - if len(message.ToolCalls) > 0 { - toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls)) - for j, tc := range message.ToolCalls { - argsJSON, err := json.Marshal(tc.Function.Arguments) - if err != nil { - return "", err - } - toolCalls[j] = olmo3ThinkToolCall{ - ID: tc.ID, - Type: "function", - Function: olmo3ThinkToolCallFunc{ - Name: tc.Function.Name, - Arguments: string(argsJSON), - }, - } - } - toolCallsJSON, err := marshalWithSpaces(toolCalls) - if err != nil { - return "", err - } - sb.WriteString("") - sb.WriteString(string(toolCallsJSON)) - sb.WriteString("") - } - - if !lastMessage { - sb.WriteString("<|im_end|>\n") - } - - case "tool": - sb.WriteString("<|im_start|>environment\n") - sb.WriteString(message.Content) sb.WriteString("<|im_end|>\n") } } - needsGenerationPrompt := true - if len(filteredMessages) > 0 { - lastMsg := filteredMessages[len(filteredMessages)-1] - if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" { - needsGenerationPrompt = false - } - } - - if needsGenerationPrompt { - 
sb.WriteString("<|im_start|>assistant\n") - } + // Always add generation prompt with tag for thinking models + sb.WriteString("<|im_start|>assistant\n") return sb.String(), nil } diff --git a/model/renderers/olmo3_think_test.go b/model/renderers/olmo3_think_test.go index 21e333e3..8bfd5fdc 100644 --- a/model/renderers/olmo3_think_test.go +++ b/model/renderers/olmo3_think_test.go @@ -11,24 +11,27 @@ import ( func TestOlmo3ThinkRenderer(t *testing.T) { tests := []struct { name string + variant Olmo3ThinkVariant msgs []api.Message tools []api.Tool expected string }{ { - name: "basic without system - adds default system", + name: "7b_basic_without_system", + variant: Olmo31Think, msgs: []api.Message{ {Role: "user", Content: "Hello!"}, }, expected: "<|im_start|>system\n" + - "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n" + + "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" + "<|im_start|>user\n" + "Hello!<|im_end|>\n" + "<|im_start|>assistant\n" + "", }, { - name: "with system message no tools", + name: "7b_with_custom_system", + variant: Olmo31Think, msgs: []api.Message{ {Role: "system", Content: "You are a helpful assistant."}, {Role: "user", Content: "Hello!"}, @@ -41,9 +44,9 @@ func TestOlmo3ThinkRenderer(t *testing.T) { "", }, { - name: "with system message and tools", + name: "7b_tools_ignored", + variant: Olmo31Think, msgs: []api.Message{ - {Role: "system", Content: "You are a helpful assistant."}, {Role: "user", Content: "What is the weather?"}, }, tools: []api.Tool{ @@ -52,27 +55,20 @@ func TestOlmo3ThinkRenderer(t *testing.T) { Function: api.ToolFunction{ Name: "get_weather", Description: "Get the current weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ - "location": {Type: api.PropertyType{"string"}, Description: "The city"}, - }, - }, }, }, }, expected: "<|im_start|>system\n" + - `You are a helpful assistant. [{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "You are Olmo, a helpful AI assistant built by Ai2. 
Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" + "<|im_start|>user\n" + "What is the weather?<|im_end|>\n" + "<|im_start|>assistant\n" + "", }, { - name: "assistant with tool calls", + name: "7b_tool_calls_and_tool_messages_ignored", + variant: Olmo31Think, msgs: []api.Message{ - {Role: "system", Content: "You are a helpful assistant."}, {Role: "user", Content: "What is the weather in SF?"}, { Role: "assistant", @@ -81,53 +77,33 @@ func TestOlmo3ThinkRenderer(t *testing.T) { { ID: "call_1", Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: map[string]any{ - "location": "San Francisco", - }, - }, - }, - }, - }, - {Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Description: "Get the current weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ - "location": {Type: api.PropertyType{"string"}, Description: "The city"}, + Name: "get_weather", + Arguments: map[string]any{"location": "San Francisco"}, }, }, }, }, + {Role: "tool", Content: `{"temperature": 68}`}, }, expected: "<|im_start|>system\n" + - `You are a helpful assistant. [{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]<|im_end|>` + "\n" + + "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" + "<|im_start|>user\n" + "What is the weather in SF?<|im_end|>\n" + "<|im_start|>assistant\n" + - `Let me check the weather.[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]<|im_end|>` + "\n" + - "<|im_start|>environment\n" + - `{"temperature": 68}<|im_end|>` + "\n" + + "Let me check the weather.<|im_end|>\n" + "<|im_start|>assistant\n" + "", }, { - name: "multi-turn conversation", + name: "7b_multi_turn_conversation", + variant: Olmo31Think, msgs: []api.Message{ - {Role: "system", Content: "You are a helpful assistant."}, {Role: "user", Content: "Hello"}, {Role: "assistant", Content: "Hi there!"}, {Role: "user", Content: "How are you?"}, }, expected: "<|im_start|>system\n" + - "You are a helpful assistant. You do not currently have access to any functions. <|im_end|>\n" + + "You are Olmo, a helpful AI assistant built by Ai2. 
Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" + "<|im_start|>user\n" + "Hello<|im_end|>\n" + "<|im_start|>assistant\n" + @@ -138,73 +114,56 @@ func TestOlmo3ThinkRenderer(t *testing.T) { "", }, { - name: "parallel tool calls", + name: "32b_basic_without_system", + variant: Olmo3Think32B, msgs: []api.Message{ - {Role: "user", Content: "Get weather in SF and NYC"}, - { - Role: "assistant", - ToolCalls: []api.ToolCall{ - { - ID: "call_1", - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: map[string]any{"location": "San Francisco"}, - }, - }, - { - ID: "call_2", - Function: api.ToolCallFunction{ - Name: "get_weather", - Arguments: map[string]any{"location": "New York"}, - }, - }, - }, - }, - {Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"}, - {Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"}, - }, - tools: []api.Tool{ - { - Type: "function", - Function: api.ToolFunction{ - Name: "get_weather", - Parameters: api.ToolFunctionParameters{ - Type: "object", - Properties: map[string]api.ToolProperty{ - "location": {Type: api.PropertyType{"string"}}, - }, - }, - }, - }, + {Role: "user", Content: "Hello!"}, }, expected: "<|im_start|>system\n" + - `You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. [{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]<|im_end|>` + "\n" + + "You are a helpful AI assistant.<|im_end|>\n" + "<|im_start|>user\n" + - "Get weather in SF and NYC<|im_end|>\n" + - "<|im_start|>assistant\n" + - `[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]<|im_end|>` + "\n" + - "<|im_start|>environment\n" + - `{"temperature": 68}<|im_end|>` + "\n" + - "<|im_start|>environment\n" + - `{"temperature": 55}<|im_end|>` + "\n" + + "Hello!<|im_end|>\n" + "<|im_start|>assistant\n" + "", }, { - name: "assistant message only content no tool calls", + name: "32b_with_custom_system_gets_suffix", + variant: Olmo3Think32B, msgs: []api.Message{ - {Role: "user", Content: "Tell me a joke"}, - {Role: "assistant", Content: "Why did the chicken cross the road?"}, - {Role: "user", Content: "I don't know, why?"}, + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, }, expected: "<|im_start|>system\n" + - "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n" + + "You are a helpful assistant. You do not currently have access to any functions. <|im_end|>\n" + "<|im_start|>user\n" + - "Tell me a joke<|im_end|>\n" + + "Hello!<|im_end|>\n" + "<|im_start|>assistant\n" + - "Why did the chicken cross the road?<|im_end|>\n" + + "", + }, + { + name: "31_basic_without_system", + variant: Olmo31Think, + msgs: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + expected: "<|im_start|>system\n" + + "You are Olmo, a helpful AI assistant built by Ai2. 
Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" + "<|im_start|>user\n" + - "I don't know, why?<|im_end|>\n" + + "Hello!<|im_end|>\n" + + "<|im_start|>assistant\n" + + "", + }, + { + name: "31_with_custom_system_gets_suffix", + variant: Olmo31Think, + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello!"}, + }, + expected: "<|im_start|>system\n" + + "You are a helpful assistant. You do not currently have access to any functions. <|im_end|>\n" + + "<|im_start|>user\n" + + "Hello!<|im_end|>\n" + "<|im_start|>assistant\n" + "", }, @@ -212,7 +171,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil) + rendered, err := (&Olmo3ThinkRenderer{Variant: tt.variant}).Render(tt.msgs, tt.tools, nil) if err != nil { t.Fatal(err) } diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 66c2f8de..2aed5dca 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -59,12 +59,27 @@ func rendererForName(name string) Renderer { case "cogito": renderer := &CogitoRenderer{isThinking: true} return renderer + case "deepseek3.1": + renderer := &DeepSeek3Renderer{IsThinking: true, Variant: Deepseek31} + return renderer case "olmo3": - renderer := &Olmo3Renderer{} + renderer := &Olmo3Renderer{UseExtendedSystemMessage: false} + return renderer + case "olmo3.1": + renderer := &Olmo3Renderer{UseExtendedSystemMessage: true} return renderer case "olmo3-think": - renderer := &Olmo3ThinkRenderer{} + // Used for Olmo-3-7B-Think and Olmo-3.1-32B-Think (same template) + renderer := &Olmo3ThinkRenderer{Variant: Olmo31Think} return renderer + case "olmo3-32b-think": + // Used for Olmo-3-32B-Think + renderer := &Olmo3ThinkRenderer{Variant: Olmo3Think32B} + return renderer + case "nemotron-3-nano": + return &Nemotron3NanoRenderer{} + case "functiongemma": + return &FunctionGemmaRenderer{} default: return nil } diff --git a/parser/parser.go b/parser/parser.go index 1f476444..5ef918bf 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -17,6 +17,7 @@ import ( "strings" "sync" + "golang.org/x/mod/semver" "golang.org/x/sync/errgroup" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" @@ -104,6 +105,16 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) req.Renderer = c.Args case "parser": req.Parser = c.Args + case "requires": + // golang.org/x/mod/semver requires "v" prefix + requires := c.Args + if !strings.HasPrefix(requires, "v") { + requires = "v" + requires + } + if !semver.IsValid(requires) { + return nil, fmt.Errorf("requires must be a valid semver (e.g. 
0.14.0)") + } + req.Requires = strings.TrimPrefix(requires, "v") case "message": role, msg, _ := strings.Cut(c.Args, ": ") messages = append(messages, api.Message{Role: role, Content: msg}) @@ -322,7 +333,7 @@ func (c Command) String() string { switch c.Name { case "model": fmt.Fprintf(&sb, "FROM %s", c.Args) - case "license", "template", "system", "adapter", "renderer", "parser": + case "license", "template", "system", "adapter", "renderer", "parser", "requires": fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") @@ -348,7 +359,7 @@ const ( var ( errMissingFrom = errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") - errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"") ) type ParserError struct { @@ -608,7 +619,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message": + case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires": return true default: return false diff --git a/server/create.go b/server/create.go index 4fdf4104..15e364e1 100644 --- a/server/create.go +++ b/server/create.go @@ -42,10 +42,10 @@ var ( ) func (s *Server) CreateHandler(c *gin.Context) { - config := &ConfigV2{ + config := &model.ConfigV2{ OS: "linux", Architecture: "amd64", - RootFS: RootFS{ + RootFS: model.RootFS{ Type: "layers", }, } @@ -61,6 +61,7 @@ func (s *Server) CreateHandler(c *gin.Context) { config.Renderer = r.Renderer config.Parser = r.Parser + config.Requires = r.Requires for v := range r.Files { if !fs.ValidPath(v) { @@ -120,13 +121,13 @@ func (s *Server) CreateHandler(c *gin.Context) { ch <- gin.H{"error": err.Error()} } - if err == nil && !remote && (config.Renderer == "" || config.Parser == "") { + if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") { manifest, mErr := ParseNamedManifest(fromName) if mErr == nil && manifest.Config.Digest != "" { configPath, pErr := GetBlobsPath(manifest.Config.Digest) if pErr == nil { if cfgFile, fErr := os.Open(configPath); fErr == nil { - var baseConfig ConfigV2 + var baseConfig model.ConfigV2 if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil { if config.Renderer == "" { config.Renderer = baseConfig.Renderer @@ -134,6 +135,9 @@ func (s *Server) CreateHandler(c *gin.Context) { if config.Parser == "" { config.Parser = baseConfig.Parser } + if config.Requires == "" { + config.Requires = baseConfig.Requires + } } cfgFile.Close() } @@ -459,7 +463,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) { return ggml.KV{}, fmt.Errorf("no base model was found") } -func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { +func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) { var layers []Layer for _, layer := range baseLayers { 
if layer.GGML != nil { @@ -789,7 +793,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) { return layers, nil } -func createConfigLayer(layers []Layer, config ConfigV2) (*Layer, error) { +func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) { digests := make([]string, len(layers)) for i, layer := range layers { digests[i] = layer.Digest diff --git a/server/images.go b/server/images.go index d3bd9ffa..951f7ac6 100644 --- a/server/images.go +++ b/server/images.go @@ -54,7 +54,7 @@ type registryOptions struct { type Model struct { Name string `json:"name"` - Config ConfigV2 + Config model.ConfigV2 ShortName string ModelPath string ParentModel string @@ -266,35 +266,6 @@ func (m *Model) String() string { return modelfile.String() } -type ConfigV2 struct { - ModelFormat string `json:"model_format"` - ModelFamily string `json:"model_family"` - ModelFamilies []string `json:"model_families"` - ModelType string `json:"model_type"` // shown as Parameter Size - FileType string `json:"file_type"` // shown as Quantization Level - Renderer string `json:"renderer,omitempty"` - Parser string `json:"parser,omitempty"` - - RemoteHost string `json:"remote_host,omitempty"` - RemoteModel string `json:"remote_model,omitempty"` - - // used for remotes - Capabilities []string `json:"capabilities,omitempty"` - ContextLen int `json:"context_length,omitempty"` - EmbedLen int `json:"embedding_length,omitempty"` - BaseName string `json:"base_name,omitempty"` - - // required by spec - Architecture string `json:"architecture"` - OS string `json:"os"` - RootFS RootFS `json:"rootfs"` -} - -type RootFS struct { - Type string `json:"type"` - DiffIDs []string `json:"diff_ids"` -} - func GetManifest(mp ModelPath) (*Manifest, string, error) { fp, err := mp.GetManifestPath() if err != nil { diff --git a/server/routes.go b/server/routes.go index 54f23d5d..b19a40fb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1106,6 +1106,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { Messages: msgs, Capabilities: m.Capabilities(), ModifiedAt: manifest.fi.ModTime(), + Requires: m.Config.Requires, } if m.Config.RemoteHost != "" { @@ -1223,7 +1224,7 @@ func (s *Server) ListHandler(c *gin.Context) { models := []api.ListModelResponse{} for n, m := range ms { - var cf ConfigV2 + var cf model.ConfigV2 if m.Config.Digest != "" { f, err := m.Config.Open() diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 909ebfe5..b1b1a288 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -241,7 +241,7 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) { } defer cfgFile.Close() - var cfg ConfigV2 + var cfg model.ConfigV2 if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil { t.Fatalf("decode config: %v", err) } diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 1aef9f14..eb8c4432 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -89,7 +89,7 @@ func TestDeleteDuplicateLayers(t *testing.T) { n := model.ParseName("test") var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(&ConfigV2{}); err != nil { + if err := json.NewEncoder(&b).Encode(&model.ConfigV2{}); err != nil { t.Fatal(err) } diff --git a/server/routes_test.go b/server/routes_test.go index bb7e2b7c..39d4f290 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -126,10 +126,10 @@ func TestRoutes(t *testing.T) { t.Fatalf("failed to create model: %v", err) } - config := &ConfigV2{ + 
config := &model.ConfigV2{ OS: "linux", Architecture: "amd64", - RootFS: RootFS{ + RootFS: model.RootFS{ Type: "layers", }, } @@ -775,7 +775,7 @@ func TestFilterThinkTags(t *testing.T) { {Role: "user", Content: "What is the answer?"}, }, model: &Model{ - Config: ConfigV2{ + Config: model.ConfigV2{ ModelFamily: "qwen3", }, }, @@ -793,7 +793,7 @@ func TestFilterThinkTags(t *testing.T) { {Role: "user", Content: "What is the answer?"}, }, model: &Model{ - Config: ConfigV2{ + Config: model.ConfigV2{ ModelFamily: "qwen3", }, }, @@ -815,7 +815,7 @@ func TestFilterThinkTags(t *testing.T) { {Role: "assistant", Content: "thinking yet againhjk"}, }, model: &Model{ - Config: ConfigV2{ + Config: model.ConfigV2{ ModelFamily: "qwen3", }, }, @@ -833,7 +833,7 @@ func TestFilterThinkTags(t *testing.T) { {Role: "user", Content: "What is the answer?"}, }, model: &Model{ - Config: ConfigV2{ + Config: model.ConfigV2{ ModelFamily: "llama3", }, }, @@ -853,7 +853,7 @@ func TestFilterThinkTags(t *testing.T) { model: &Model{ Name: "registry.ollama.ai/library/deepseek-r1:latest", ShortName: "deepseek-r1:7b", - Config: ConfigV2{}, + Config: model.ConfigV2{}, }, }, } diff --git a/types/model/config.go b/types/model/config.go new file mode 100644 index 00000000..96aec8cb --- /dev/null +++ b/types/model/config.go @@ -0,0 +1,33 @@ +package model + +// ConfigV2 represents the configuration metadata for a model. +type ConfigV2 struct { + ModelFormat string `json:"model_format"` + ModelFamily string `json:"model_family"` + ModelFamilies []string `json:"model_families"` + ModelType string `json:"model_type"` // shown as Parameter Size + FileType string `json:"file_type"` // shown as Quantization Level + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + Requires string `json:"requires,omitempty"` + + RemoteHost string `json:"remote_host,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + + // used for remotes + Capabilities []string `json:"capabilities,omitempty"` + ContextLen int `json:"context_length,omitempty"` + EmbedLen int `json:"embedding_length,omitempty"` + BaseName string `json:"base_name,omitempty"` + + // required by spec + Architecture string `json:"architecture"` + OS string `json:"os"` + RootFS RootFS `json:"rootfs"` +} + +// RootFS represents the root filesystem configuration for a model. +type RootFS struct { + Type string `json:"type"` + DiffIDs []string `json:"diff_ids"` +}
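
The following is an illustrative sketch, not part of the patch: it shows how the new REQUIRES directive and the relocated ConfigV2.Requires field fit together. The semver normalization mirrors the parser.go hunk above; the import path github.com/ollama/ollama/types/model is assumed to match the package placement of the new types/model/config.go file, and the standalone main function exists only for demonstration.

package main

import (
	"encoding/json"
	"fmt"
	"strings"

	"golang.org/x/mod/semver"

	"github.com/ollama/ollama/types/model"
)

// normalizeRequires mirrors the parser.go validation: golang.org/x/mod/semver
// expects a leading "v", so one is added before validating and stripped again
// before the value is stored on the CreateRequest.
func normalizeRequires(requires string) (string, error) {
	v := requires
	if !strings.HasPrefix(v, "v") {
		v = "v" + v
	}
	if !semver.IsValid(v) {
		return "", fmt.Errorf("requires must be a valid semver (e.g. 0.14.0)")
	}
	return strings.TrimPrefix(v, "v"), nil
}

func main() {
	requires, err := normalizeRequires("0.14.0")
	if err != nil {
		panic(err)
	}

	// The minimum Ollama version travels alongside renderer/parser in the model
	// config blob; empty fields are dropped from the JSON via omitempty.
	cfg := model.ConfigV2{
		OS:           "linux",
		Architecture: "amd64",
		Requires:     requires,
		RootFS:       model.RootFS{Type: "layers"},
	}

	b, _ := json.Marshal(cfg)
	fmt.Println(string(b))
}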