diff --git a/README.md b/README.md index fddb7421..3ce8c3ec 100644 --- a/README.md +++ b/README.md @@ -414,6 +414,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) - [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) +- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) +- [screenpipe](https://github.com/mediar-ai/screenpipe) (Build agents powered by your screen history) ### Cloud diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index f21a8f50..41b03e1b 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -757,3 +757,132 @@ func TestCreateHandler(t *testing.T) { }) } } + +func TestNewCreateRequest(t *testing.T) { + tests := []struct { + name string + from string + opts runOptions + expected *api.CreateRequest + }{ + { + "basic test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "", + Prompt: "You are a fun AI agent", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "parent model test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + }, + }, + { + "parent model as filepath test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "/some/file/like/etc/passwd", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "parent model as windows filepath test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "D:\\some\\file\\like\\etc\\passwd", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "options test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + Options: map[string]any{ + "temperature": 1.0, + }, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + Parameters: map[string]any{ + "temperature": 1.0, + }, + }, + }, + { + "messages test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + System: "You are a fun AI agent", + Messages: []api.Message{ + { + Role: "user", + Content: "hello there!", + }, + { + Role: "assistant", + Content: "hello to you!", + }, + }, + WordWrap: true, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + System: "You are a fun AI agent", + Messages: []api.Message{ + { + Role: "user", + Content: "hello there!", + }, + { + Role: "assistant", + Content: "hello to you!", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := NewCreateRequest(tt.from, tt.opts) + if !cmp.Equal(actual, tt.expected) { + t.Errorf("expected output %#v, got %#v", tt.expected, actual) + } + }) + } +} diff --git a/cmd/interactive.go b/cmd/interactive.go index f3489b65..d85510d4 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -18,6 +18,7 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/types/errtypes" + 
"github.com/ollama/ollama/types/model" ) type MultilineState int @@ -459,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } func NewCreateRequest(name string, opts runOptions) *api.CreateRequest { + parentModel := opts.ParentModel + + modelName := model.ParseName(parentModel) + if !modelName.IsValid() { + parentModel = "" + } + req := &api.CreateRequest{ - Name: name, - From: cmp.Or(opts.ParentModel, opts.Model), + Model: name, + From: cmp.Or(parentModel, opts.Model), } if opts.System != "" { diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index c7b56890..fbbd9d5c 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) { DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) } +func TestIntegrationSplitBatch(t *testing.T) { + image, err := base64.StdEncoding.DecodeString(imageEncoding) + require.NoError(t, err) + req := api.GenerateRequest{ + Model: "gemma3:4b", + // Fill up a chunk of the batch so the image will partially spill over into the next one + System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. 
Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.", + Prompt: "what does the text in this image say?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + Images: []api.ImageData{ + image, + }, + } + + // Note: sometimes it returns "the ollamas" sometimes "the ollams" + resp := "the ollam" + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + require.NoError(t, PullIfMissing(ctx, client, req.Model)) + // llava models on CPU can be quite slow to start, + DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second) +} + const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6 diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b6f20286..b443fcd3 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index ec742224..aad92a5d 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -41,6 +41,7 @@ enum llm_arch { LLM_ARCH_MINICPM3, LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index ab1a07d1..70183041 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA3: + { + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA3: + { + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4029,6 +4035,7 @@ enum 
llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: + case LLM_ARCH_GEMMA3: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index 6eb1da08..d2f3a510 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // This used to be a regex, but has an extreme cost to compile times. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + // don't quantize vision stuff + quantize &= name.find("v.blk.") == std::string::npos; + + quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; + quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; + quantize &= name.find("v.patch_embedding.weight") == std::string::npos; + quantize &= name.find("v.position_embedding.weight") == std::string::npos; + quantize &= name.find("v.post_layernorm.weight") == std::string::npos; + // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0021-gemma3-quantization.patch b/llama/patches/0021-gemma3-quantization.patch new file mode 100644 index 00000000..4f6dbc11 --- /dev/null +++ b/llama/patches/0021-gemma3-quantization.patch @@ -0,0 +1,113 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Patrick Devine +Date: Fri, 14 Mar 2025 16:33:23 -0700 +Subject: [PATCH] gemma3 quantization + +--- + src/llama-arch.cpp | 19 +++++++++++++++++++ + src/llama-arch.h | 1 + + src/llama-model.cpp | 7 +++++++ + src/llama-quant.cpp | 9 +++++++++ + 4 files changed, 36 insertions(+) + +diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp +index b6f20286..b443fcd3 100644 +--- a/src/llama-arch.cpp ++++ b/src/llama-arch.cpp +@@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, ++ { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, +@@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, ++ { ++ LLM_ARCH_GEMMA3, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, ++ }, ++ }, + { + LLM_ARCH_STARCODER2, + { +diff --git a/src/llama-arch.h b/src/llama-arch.h +index ec742224..aad92a5d 100644 +--- a/src/llama-arch.h ++++ b/src/llama-arch.h +@@ -41,6 +41,7 @@ enum llm_arch { + LLM_ARCH_MINICPM3, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, ++ LLM_ARCH_GEMMA3, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, +diff --git a/src/llama-model.cpp b/src/llama-model.cpp +index ab1a07d1..70183041 100644 +--- a/src/llama-model.cpp ++++ b/src/llama-model.cpp +@@ -878,6 
+878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; ++ case LLM_ARCH_GEMMA3: ++ { ++ } break; + case LLM_ARCH_STARCODER2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; ++ case LLM_ARCH_GEMMA3: ++ { ++ } break; + case LLM_ARCH_STARCODER2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); +@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { + case LLM_ARCH_PHIMOE: + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: ++ case LLM_ARCH_GEMMA3: + case LLM_ARCH_STARCODER2: + case LLM_ARCH_OPENELM: + case LLM_ARCH_GPTNEOX: +diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp +index 6eb1da08..d2f3a510 100644 +--- a/src/llama-quant.cpp ++++ b/src/llama-quant.cpp +@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + ++ // don't quantize vision stuff ++ quantize &= name.find("v.blk.") == std::string::npos; ++ ++ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; ++ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; ++ quantize &= name.find("v.patch_embedding.weight") == std::string::npos; ++ quantize &= name.find("v.position_embedding.weight") == std::string::npos; ++ quantize &= name.find("v.post_layernorm.weight") == std::string::npos; ++ + // quantize only 2D and 3D tensors (experts) + quantize &= (ggml_n_dims(tensor) >= 2); + diff --git a/llm/server.go b/llm/server.go index c6f11712..adc11aae 100644 --- a/llm/server.go +++ b/llm/server.go @@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) } - slog.Info("starting llama server", "cmd", s.cmd.String()) + slog.Info("starting llama server", "cmd", s.cmd) if envconfig.Debug() { filteredEnv := []string{} for _, ev := range s.cmd.Env { @@ -470,7 +470,7 @@ const ( // iota is reset to 0 ServerStatusError ) -func (s ServerStatus) ToString() string { +func (s ServerStatus) String() string { switch s { case ServerStatusReady: return "llm server ready" @@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string { } } -type ServerStatusResp struct { - Status string `json:"status"` - SlotsIdle int `json:"slots_idle"` - SlotsProcessing int `json:"slots_processing"` - Error string `json:"error"` - Progress float32 `json:"progress"` +type ServerStatusResponse struct { + Status ServerStatus `json:"status"` + Progress float32 `json:"progress"` } func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { @@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { } if s.cmd.ProcessState.ExitCode() == -1 { // Most likely a signal killed it, log some more details to try to help troubleshoot - slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String()) + slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState) } return ServerStatusError, fmt.Errorf("llama 
runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) } @@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { return ServerStatusError, fmt.Errorf("read health request: %w", err) } - var status ServerStatusResp - if err := json.Unmarshal(body, &status); err != nil { + var ssr ServerStatusResponse + if err := json.Unmarshal(body, &ssr); err != nil { return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err) } - switch status.Status { - case "ok": - return ServerStatusReady, nil - case "no slot available": - return ServerStatusNoSlotsAvailable, nil - case "loading model": - s.loadProgress = status.Progress - return ServerStatusLoadingModel, nil + switch ssr.Status { + case ServerStatusLoadingModel: + s.loadProgress = ssr.Progress + return ssr.Status, nil + case ServerStatusReady, ServerStatusNoSlotsAvailable: + return ssr.Status, nil default: - return ServerStatusError, fmt.Errorf("server error: %+v", status) + return ssr.Status, fmt.Errorf("server error: %+v", ssr) } } @@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { status, _ := s.getServerStatus(ctx) if lastStatus != status && status != ServerStatusReady { // Only log on status changes - slog.Info("waiting for server to become available", "status", status.ToString()) + slog.Info("waiting for server to become available", "status", status) } switch status { case ServerStatusReady: @@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress)) stallTimer = time.Now().Add(stallDuration) } else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 { - slog.Debug("model load completed, waiting for server to become available", "status", status.ToString()) + slog.Debug("model load completed, waiting for server to become available", "status", status) stallTimer = time.Now().Add(stallDuration) fullyLoaded = true } @@ -671,63 +666,26 @@ type ImageData struct { AspectRatioID int `json:"aspect_ratio_id"` } -type completion struct { - Content string `json:"content"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Stop bool `json:"stop"` - StoppedLimit bool `json:"stopped_limit"` - - Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` - } -} - type CompletionRequest struct { Prompt string Format json.RawMessage Images []ImageData Options *api.Options + + Grammar string // set before sending the request to the subprocess } type CompletionResponse struct { - Content string - DoneReason string - Done bool - PromptEvalCount int - PromptEvalDuration time.Duration - EvalCount int - EvalDuration time.Duration + Content string `json:"content"` + DoneReason string `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { - request := map[string]any{ - "prompt": req.Prompt, - "stream": true, - "n_predict": req.Options.NumPredict, - "n_keep": req.Options.NumKeep, - "main_gpu": req.Options.MainGPU, - "temperature": req.Options.Temperature, - "top_k": req.Options.TopK, - "top_p": req.Options.TopP, - "min_p": 
req.Options.MinP, - "typical_p": req.Options.TypicalP, - "repeat_last_n": req.Options.RepeatLastN, - "repeat_penalty": req.Options.RepeatPenalty, - "presence_penalty": req.Options.PresencePenalty, - "frequency_penalty": req.Options.FrequencyPenalty, - "mirostat": req.Options.Mirostat, - "mirostat_tau": req.Options.MirostatTau, - "mirostat_eta": req.Options.MirostatEta, - "seed": req.Options.Seed, - "stop": req.Options.Stop, - "image_data": req.Images, - "cache_prompt": true, - } - if len(req.Format) > 0 { switch string(req.Format) { case `null`, `""`: @@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu // these as "not set". break case `"json"`: - request["grammar"] = grammarJSON + req.Grammar = grammarJSON default: if req.Format[0] != '{' { return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) @@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if g == nil { return fmt.Errorf("invalid JSON schema in format") } - request["grammar"] = string(g) + req.Grammar = string(g) } } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + if err := s.sem.Acquire(ctx, 1); err != nil { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") @@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if err != nil { return err } else if status != ServerStatusReady { - return fmt.Errorf("unexpected server status: %s", status.ToString()) + return fmt.Errorf("unexpected server status: %s", status) } // Handling JSON marshaling with special characters unescaped. @@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu enc := json.NewEncoder(buffer) enc.SetEscapeHTML(false) - if err := enc.Encode(request); err != nil { + if err := enc.Encode(req); err != nil { return fmt.Errorf("failed to marshal data: %v", err) } @@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu evt = line } - var c completion + var c CompletionResponse if err := json.Unmarshal(evt, &c); err != nil { return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } @@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu }) } - if c.Stop { - doneReason := "stop" - if c.StoppedLimit { - doneReason = "length" - } - - fn(CompletionResponse{ - Done: true, - DoneReason: doneReason, - PromptEvalCount: c.Timings.PromptN, - PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), - EvalCount: c.Timings.PredictedN, - EvalDuration: parseDurationMs(c.Timings.PredictedMS), - }) + if c.Done { + fn(c) return nil } } @@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err if err != nil { return nil, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + return nil, fmt.Errorf("unexpected server status: %s", status) } data, err := json.Marshal(EmbeddingRequest{Content: input}) @@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 { } return 0 } - -func parseDurationMs(ms float64) time.Duration { - dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) - if err != nil { - panic(err) - } - - return dur -} diff --git a/model/input/input.go b/model/input/input.go index 0cb3f3f4..30bdcf06 100644 --- 
a/model/input/input.go +++ b/model/input/input.go @@ -15,6 +15,12 @@ type Input struct { // stored in Multimodal, used for caching and comparing // equality. MultimodalHash uint64 + + // SameBatch forces the following number of tokens to be processed + // in a single batch, breaking and extending batches as needed. + // Useful for things like images that must be processed in one + // shot. + SameBatch int } // MultimodalIndex is a multimodal element (such as an image) diff --git a/model/model.go b/model/model.go index fadea324..53e47add 100644 --- a/model/model.go +++ b/model/model.go @@ -60,7 +60,7 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. - PostTokenize(ml.Context, []input.Input) ([]input.Input, error) + PostTokenize([]input.Input) ([]input.Input, error) } // Base implements the common fields and methods for all models diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 24193f15..32ad80f4 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -2,10 +2,9 @@ package gemma3 import ( "bytes" - "encoding/binary" - "hash/fnv" "image" "math" + "slices" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" @@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return visionOutputs, nil } -type imageToken struct { - embedding ml.Tensor - index int -} - -func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input - fnvHash := fnv.New64a() for _, inp := range inputs { if inp.Multimodal == nil { result = append(result, inp) } else { - imageInputs := []input.Input{ - {Token: 108}, // "\n\n" - {Token: 255999}, // """ - } - result = append(result, imageInputs...) - - // add image embeddings inputMultimodal := inp.Multimodal.(ml.Tensor) - for i := range inputMultimodal.Dim(1) { - fnvHash.Reset() - binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash) - fnvHash.Write([]byte{byte(i)}) + result = append(result, + input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + input.Input{Token: 255999}, // """ + input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + ) - imageToken := imageToken{embedding: inputMultimodal, index: i} - result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()}) - } + // add image token placeholders + result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) 
result = append(result, input.Input{Token: 256000}, // diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 7a88c092..567f65a5 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int { - var embedding ml.Tensor - var src, dst, length int - var except []int - - for _, image := range multimodal { - imageToken := image.Multimodal.(imageToken) - imageSrc := imageToken.index - imageDst := image.Index - - if embedding == nil { - embedding = imageToken.embedding - src = imageSrc - dst = imageDst - length = 1 - } else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst { - src = imageSrc - dst = imageDst - length++ - } else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst { - length++ - } else { - visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0)) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0)))) - - embedding = imageToken.embedding - src = imageSrc - dst = imageDst - length = 1 - } - - except = append(except, imageDst) - } - - if embedding != nil { - visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0)) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0)))) - } - - return except -} - func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) - except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal) + // set image embeddings + var except []int + for _, image := range opts.Multimodal { + visionOutputs := image.Multimodal.(ml.Tensor) + ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + + for i := range visionOutputs.Dim(1) { + except = append(except, image.Index+i) + } + } for i, layer := range m.Layers { // gemma alternates between the sliding window (local) and causal (global) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 071d77ac..fa4d570c 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return m.Projector.Forward(ctx, crossAttentionStates), nil } -func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var images []input.Input fnvHash := fnv.New64a() for i := range inputs { if inputs[i].Multimodal == nil { if len(images) > 0 { - inputs[i].Multimodal = images[0].Multimodal + inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)} inputs[i].MultimodalHash = images[0].MultimodalHash for j := 1; j < len(images); j++ { - inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3) + inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[j].Multimodal.(ml.Tensor)) fnvHash.Reset() 
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) @@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { var crossAttentionStates ml.Tensor if len(opts.Multimodal) > 0 { - crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor) + images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) + if len(images) > 0 { + crossAttentionStates = images[len(images)-1] + } } inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 8662afc1..83802d60 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/runner/common" ) @@ -99,7 +100,7 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() @@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // generating image embeddings for each image -func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) { var inputs []input var parts []string var matches [][]string @@ -229,7 +230,7 @@ type Server struct { image *ImageContext // status for external health reporting - loading, ready to serve, etc. 
- status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var samplingParams llama.SamplingParams - samplingParams.TopK = req.TopK - samplingParams.TopP = req.TopP - samplingParams.MinP = req.MinP - samplingParams.TypicalP = req.TypicalP - samplingParams.Temp = req.Temperature - samplingParams.RepeatLastN = req.RepeatLastN - samplingParams.PenaltyRepeat = req.RepeatPenalty - samplingParams.PenaltyFreq = req.FrequencyPenalty - samplingParams.PenaltyPresent = req.PresencePenalty - samplingParams.Mirostat = req.Mirostat - samplingParams.MirostatTau = req.MirostatTau - samplingParams.MirostatEta = req.MirostatEta - samplingParams.Seed = uint32(req.Seed) - samplingParams.Grammar = req.Grammar + // Extract options from the CompletionRequest + samplingParams := llama.SamplingParams{ + TopK: req.Options.TopK, + TopP: req.Options.TopP, + MinP: req.Options.MinP, + TypicalP: req.Options.TypicalP, + Temp: req.Options.Temperature, + RepeatLastN: req.Options.RepeatLastN, + 
PenaltyRepeat: req.Options.RepeatPenalty, + PenaltyFreq: req.Options.FrequencyPenalty, + PenaltyPresent: req.Options.PresencePenalty, + Mirostat: req.Options.Mirostat, + MirostatTau: req.Options.MirostatTau, + MirostatEta: req.Options.MirostatEta, + Seed: uint32(req.Options.Seed), + Grammar: req.Grammar, + } seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: req.NumKeep, + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: req.Options.NumKeep, samplingParams: &samplingParams, embedding: false, }) @@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numDecoded, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numDecoded, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - var req EmbeddingRequest + var req llm.EmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) return @@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding - if err := 
json.NewEncoder(w).Encode(&EmbeddingResponse{ + if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ Embedding: embedding, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } } -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -879,7 +794,7 @@ func (s *Server) loadModel( panic(err) } - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -937,7 +852,7 @@ func Execute(args []string) error { parallel: *parallel, seqs: make([]*Sequence, *parallel), seqsSem: semaphore.NewWeighted(int64(*parallel)), - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } var tensorSplitFloats []float32 diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index a411fddb..adcb3f73 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } + // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? 
if !cachePrompt { numPast = 0 } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index c380ef22..d4c24556 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -24,6 +24,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -33,10 +34,14 @@ import ( _ "github.com/ollama/ollama/model/models" ) +type contextList struct { + list []ml.Context +} + type Sequence struct { - // ctx for allocating tensors that last the lifetime of the sequence, such as + // ctxs are used for allocating tensors that last the lifetime of the sequence, such as // multimodal embeddings - ctx ml.Context + ctxs *contextList // batch index iBatch int @@ -94,13 +99,12 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() - ctx := s.model.Backend().NewContext() - inputs, err := s.inputs(ctx, prompt, images) + inputs, ctxs, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) } else if len(inputs) == 0 { @@ -126,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // TODO(jessegross): Ingest cached history for grammar return &Sequence{ - ctx: ctx, + ctxs: ctxs, inputs: inputs, numPromptInputs: len(inputs), startProcessingTime: startTime, @@ -145,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) { var inputs []input.Input var parts []string var matches [][]string @@ -160,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in parts = []string{prompt} } + var contexts contextList + runtime.AddCleanup(&contexts, func(ctxs []ml.Context) { + for _, ctx := range ctxs { + ctx.Close() + } + }, contexts.list) + postTokenize := false for i, part := range parts { // text - tokenize tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { - return nil, err + return nil, nil, err } for _, t := range tokens { @@ -185,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in } if imageIndex < 0 { - return nil, fmt.Errorf("invalid image index: %d", n) + return nil, nil, fmt.Errorf("invalid image index: %d", n) } + ctx := s.model.Backend().NewContext() + contexts.list = append(contexts.list, ctx) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) if err != nil { - return nil, err + return nil, nil, err } s.multimodalHash.Reset() @@ -204,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in if visionModel && postTokenize { var err error - inputs, err = multimodalProcessor.PostTokenize(ctx, inputs) + inputs, err = multimodalProcessor.PostTokenize(inputs) if err != nil { - return nil, err + return nil, nil, err } } - 
return inputs, nil + return inputs, &contexts, nil } type Server struct { @@ -222,7 +235,7 @@ type Server struct { model model.Model // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -305,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) { close(seq.responses) close(seq.embedding) seq.cache.InUse = false - seq.ctx.Close() s.seqs[seqIndex] = nil s.seqsSem.Release(1) } @@ -351,6 +363,8 @@ func (s *Server) processBatch() error { seq.cache.Inputs = []input.Input{} } + batchSize := s.batchSize + for j, inp := range seq.inputs { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if len(seq.pendingInputs) == 0 { @@ -363,7 +377,15 @@ func (s *Server) processBatch() error { } } - if j >= s.batchSize { + // If we are required to put following inputs into a single batch then extend the + // batch size. Since we are only extending the size the minimum amount possible, this + // will cause a break if we have pending inputs. + minBatch := 1 + inp.SameBatch + if minBatch > batchSize { + batchSize = minBatch + } + + if len(seq.pendingInputs)+minBatch > batchSize { break } @@ -501,75 +523,18 @@ func (s *Server) processBatch() error { return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming 
w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -591,18 +556,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } sampler := sample.NewSampler( - req.Temperature, - req.TopK, - req.TopP, - req.MinP, - req.Seed, + req.Options.Temperature, + req.Options.TopK, + req.Options.TopP, + req.Options.MinP, + req.Options.Seed, grammar, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: int32(req.NumKeep), + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: int32(req.Options.NumKeep), sampler: sampler, embedding: false, }) @@ -625,7 +590,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -652,7 +617,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -663,15 +628,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numPredicted, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numPredicted, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -682,43 +649,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: 
s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -772,7 +706,7 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -824,7 +758,7 @@ func Execute(args []string) error { server := &Server{ batchSize: *batchSize, - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } // TODO(jessegross): Parameters that need to be implemented: diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 76d0a6c2..616e8501 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -8,7 +8,7 @@ usage() { exit 1 } -export VERSION=${VERSION:-$(git describe --tags --dirty)} +export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")} export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" export CGO_CPPFLAGS='-mmacosx-version-min=11.3' diff --git a/server/internal/cache/blob/chunked.go b/server/internal/cache/blob/chunked.go index 5faea84f..3f62127a 100644 --- a/server/internal/cache/blob/chunked.go +++ b/server/internal/cache/blob/chunked.go @@ -5,11 +5,18 @@ import ( "errors" "io" "os" - - "github.com/ollama/ollama/server/internal/chunks" ) -type Chunk = chunks.Chunk // TODO: move chunks here? +// Chunk represents a range of bytes in a blob. +type Chunk struct { + Start int64 + End int64 +} + +// Size returns end minus start plus one. +func (c Chunk) Size() int64 { + return c.End - c.Start + 1 +} // Chunker writes to a blob in chunks. // Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker. diff --git a/server/internal/chunks/chunks.go b/server/internal/chunks/chunks.go deleted file mode 100644 index 7bb4e99a..00000000 --- a/server/internal/chunks/chunks.go +++ /dev/null @@ -1,81 +0,0 @@ -package chunks - -import ( - "fmt" - "iter" - "strconv" - "strings" -) - -type Chunk struct { - Start, End int64 -} - -func New(start, end int64) Chunk { - return Chunk{start, end} -} - -// ParseRange parses a string in the form "unit=range" where unit is a string -// and range is a string in the form "start-end". It returns the unit and the -// range as a Chunk. -func ParseRange(s string) (unit string, _ Chunk, _ error) { - unit, r, _ := strings.Cut(s, "=") - if r == "" { - return unit, Chunk{}, nil - } - c, err := Parse(r) - if err != nil { - return "", Chunk{}, err - } - return unit, c, err -} - -// Parse parses a string in the form "start-end" and returns the Chunk. -func Parse[S ~string | ~[]byte](s S) (Chunk, error) { - startPart, endPart, found := strings.Cut(string(s), "-") - if !found { - return Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s) - } - start, err := strconv.ParseInt(startPart, 10, 64) - if err != nil { - return Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err) - } - end, err := strconv.ParseInt(endPart, 10, 64) - if err != nil { - return Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err) - } - if start > end { - return Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s) - } - return Chunk{start, end}, nil -} - -// Of returns a sequence of contiguous Chunks of size chunkSize that cover -// the range [0, size), in order. 
-func Of(size, chunkSize int64) iter.Seq[Chunk] { - return func(yield func(Chunk) bool) { - for start := int64(0); start < size; start += chunkSize { - end := min(start+chunkSize-1, size-1) - if !yield(Chunk{start, end}) { - break - } - } - } -} - -// Count returns the number of Chunks of size chunkSize needed to cover the -// range [0, size). -func Count(size, chunkSize int64) int64 { - return (size + chunkSize - 1) / chunkSize -} - -// Size returns end minus start plus one. -func (c Chunk) Size() int64 { - return c.End - c.Start + 1 -} - -// String returns the string representation of the Chunk in the form -// "{start}-{end}". -func (c Chunk) String() string { - return fmt.Sprintf("%d-%d", c.Start, c.End) -} diff --git a/server/internal/chunks/chunks_test.go b/server/internal/chunks/chunks_test.go deleted file mode 100644 index c23e0de8..00000000 --- a/server/internal/chunks/chunks_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package chunks - -import ( - "slices" - "testing" -) - -func TestOf(t *testing.T) { - cases := []struct { - total int64 - chunkSize int64 - want []Chunk - }{ - {0, 1, nil}, - {1, 1, []Chunk{{0, 0}}}, - {1, 2, []Chunk{{0, 0}}}, - {2, 1, []Chunk{{0, 0}, {1, 1}}}, - {10, 9, []Chunk{{0, 8}, {9, 9}}}, - } - - for _, tt := range cases { - got := slices.Collect(Of(tt.total, tt.chunkSize)) - if !slices.Equal(got, tt.want) { - t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want) - } - } -} - -func TestSize(t *testing.T) { - cases := []struct { - c Chunk - want int64 - }{ - {Chunk{0, 0}, 1}, - {Chunk{0, 1}, 2}, - {Chunk{3, 4}, 2}, - } - - for _, tt := range cases { - got := tt.c.Size() - if got != tt.want { - t.Errorf("%v: got %d; want %d", tt.c, got, tt.want) - } - } -} - -func TestCount(t *testing.T) { - cases := []struct { - total int64 - chunkSize int64 - want int64 - }{ - {0, 1, 0}, - {1, 1, 1}, - {1, 2, 1}, - {2, 1, 2}, - {10, 9, 2}, - } - for _, tt := range cases { - got := Count(tt.total, tt.chunkSize) - if got != tt.want { - t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want) - } - } -} diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index baf42262..d1d01ba4 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -25,6 +25,7 @@ import ( "os" "path/filepath" "runtime" + "runtime/debug" "slices" "strconv" "strings" @@ -36,7 +37,6 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/server/internal/cache/blob" - "github.com/ollama/ollama/server/internal/chunks" "github.com/ollama/ollama/server/internal/internal/backoff" "github.com/ollama/ollama/server/internal/internal/names" @@ -260,6 +260,7 @@ func DefaultRegistry() (*Registry, error) { } var rc Registry + rc.UserAgent = UserAgent() rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) if err != nil { return nil, err @@ -275,6 +276,16 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } +func UserAgent() string { + buildinfo, _ := debug.ReadBuildInfo() + return fmt.Sprintf("ollama/%s (%s %s) Go/%s", + buildinfo.Main.Version, + runtime.GOARCH, + runtime.GOOS, + runtime.Version(), + ) +} + func (r *Registry) maxStreams() int { return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) } @@ -500,7 +511,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { if err != nil { return err } - req.Header.Set("Range", fmt.Sprintf("bytes=%s", cs.Chunk)) + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End)) res, err := 
sendRequest(r.client(), req) if err != nil { return err @@ -794,7 +805,7 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se yield(chunksum{}, err) return } - chunk, err := chunks.Parse(s.Bytes()) + chunk, err := parseChunk(s.Bytes()) if err != nil { yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes())) return @@ -1059,3 +1070,23 @@ func splitExtended(s string) (scheme, name, digest string) { } return scheme, s, digest } + +// parseChunk parses a string in the form "start-end" and returns the Chunk. +func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) { + startPart, endPart, found := strings.Cut(string(s), "-") + if !found { + return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s) + } + start, err := strconv.ParseInt(startPart, 10, 64) + if err != nil { + return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err) + } + end, err := strconv.ParseInt(endPart, 10, 64) + if err != nil { + return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err) + } + if start > end { + return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s) + } + return blob.Chunk{Start: start, End: end}, nil +} diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go index ecfc6326..30fb58ab 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -21,7 +21,6 @@ import ( "time" "github.com/ollama/ollama/server/internal/cache/blob" - "github.com/ollama/ollama/server/internal/chunks" "github.com/ollama/ollama/server/internal/testutil" ) @@ -531,56 +530,6 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) { } } -func TestRegistryPullChunking(t *testing.T) { - t.Skip("TODO: BRING BACK BEFORE LANDING") - - rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) { - t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) - if r.URL.Host != "blob.store" { - // The production registry redirects to the blob store. - http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound) - return - } - if strings.Contains(r.URL.Path, "/blobs/") { - rng := r.Header.Get("Range") - if rng == "" { - http.Error(w, "missing range", http.StatusBadRequest) - return - } - _, c, err := chunks.ParseRange(r.Header.Get("Range")) - if err != nil { - panic(err) - } - io.WriteString(w, "remote"[c.Start:c.End+1]) - return - } - fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote")) - }) - - // Force chunking by setting the threshold to less than the size of the - // layer. - rc.ChunkingThreshold = 3 - rc.MaxChunkSize = 3 - - var reads []int64 - ctx := WithTrace(t.Context(), &Trace{ - Update: func(d *Layer, n int64, err error) { - if err != nil { - t.Errorf("update %v %d %v", d, n, err) - } - reads = append(reads, n) - }, - }) - - err := rc.Pull(ctx, "remote") - testutil.Check(t, err) - - want := []int64{0, 3, 6} - if !slices.Equal(reads, want) { - t.Errorf("reads = %v; want %v", reads, want) - } -} - func TestRegistryResolveByDigest(t *testing.T) { check := testutil.Checker(t) diff --git a/server/internal/cmd/oppbench/oppbench.go b/server/internal/cmd/oppbench/oppbench.go deleted file mode 100644 index 7a530594..00000000 --- a/server/internal/cmd/oppbench/oppbench.go +++ /dev/null @@ -1,11 +0,0 @@ -package main - -import ( - "fmt" - "os" -) - -func main() { - fmt.Println("Run as 'go test -bench=.' 
to run the benchmarks") - os.Exit(1) -} diff --git a/server/internal/cmd/oppbench/oppbench_test.go b/server/internal/cmd/oppbench/oppbench_test.go deleted file mode 100644 index c71d6cde..00000000 --- a/server/internal/cmd/oppbench/oppbench_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package main - -import ( - "bytes" - "context" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "runtime" - "sync/atomic" - "testing" - "time" - - "github.com/ollama/ollama/server/internal/chunks" - "golang.org/x/sync/errgroup" -) - -func BenchmarkDownload(b *testing.B) { - run := func(fileSize, chunkSize int64) { - name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize) - b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) }) - } - - run(100<<20, 8<<20) - run(100<<20, 16<<20) - run(100<<20, 32<<20) - run(100<<20, 64<<20) - run(100<<20, 128<<20) // 1 chunk -} - -func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error { - const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d" - req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil) - if err != nil { - return err - } - req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) - res, err := c.Do(req) - if err != nil { - return err - } - defer res.Body.Close() - - _, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read - return err -} - -var sleepTime atomic.Int64 - -func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) { - client := &http.Client{ - Transport: func() http.RoundTripper { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.DisableKeepAlives = true - return tr - }(), - } - defer client.CloseIdleConnections() - - // warm up the client - run(context.Background(), client, chunks.New(0, 1<<20)) - - b.SetBytes(fileSize) - b.ReportAllocs() - - // Give our CDN a min to breathe between benchmarks. - time.Sleep(time.Duration(sleepTime.Swap(3))) - - for b.Loop() { - g, ctx := errgroup.WithContext(b.Context()) - g.SetLimit(runtime.GOMAXPROCS(0)) - for chunk := range chunks.Of(fileSize, chunkSize) { - g.Go(func() error { return run(ctx, client, chunk) }) - } - if err := g.Wait(); err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkWrite(b *testing.B) { - b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) }) -} - -func benchmarkWrite(b *testing.B, chunkSize int) { - b.ReportAllocs() - - dir := b.TempDir() - f, err := os.Create(filepath.Join(dir, "write-single")) - if err != nil { - b.Fatal(err) - } - defer f.Close() - - data := make([]byte, chunkSize) - b.SetBytes(int64(chunkSize)) - r := bytes.NewReader(data) - for b.Loop() { - r.Reset(data) - _, err := io.Copy(f, r) - if err != nil { - b.Fatal(err) - } - } -} diff --git a/server/prompt.go b/server/prompt.go index d053f2a8..5b5b958f 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var system []api.Message isMllama := checkMllamaModelFamily(m) - isGemma3 := checkGemma3ModelFamily(m) var imageNumTokens int // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
n := len(msgs) - 1 // in reverse, find all messages that fit into context window for i := n; i >= 0; i-- { - if (isMllama || isGemma3) && len(msgs[i].Images) > 1 { + if isMllama && len(msgs[i].Images) > 1 { return "", nil, errTooManyImages } @@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool { } return false } - -func checkGemma3ModelFamily(m *Model) bool { - for _, arch := range m.Config.ModelFamilies { - if arch == "gemma3" { - return true - } - } - return false -}
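
The registry changes above replace the old server/internal/chunks package with a plain blob.Chunk type and a private parseChunk helper, and Pull now builds the HTTP Range header directly from the chunk's fields. The standalone Go sketch below is illustrative only and is not part of the diff: it assumes the inclusive Start/End semantics shown for blob.Chunk and mirrors the behaviour of the deleted chunks.Of iterator, just to show how the ranges map onto "bytes=start-end" headers. The chunksOf helper name is invented here for the example.

// Illustrative sketch only -- not part of the diff above. It assumes the
// blob.Chunk semantics introduced there: Start and End are inclusive byte
// offsets, so Size() is End - Start + 1, and Pull writes the Range header
// as "bytes=<start>-<end>" from those fields.
package main

import "fmt"

// Chunk mirrors server/internal/cache/blob.Chunk.
type Chunk struct {
	Start, End int64
}

// Size returns end minus start plus one.
func (c Chunk) Size() int64 { return c.End - c.Start + 1 }

// chunksOf covers [0, size) with contiguous chunks of chunkSize, matching
// the behaviour of the deleted chunks.Of iterator (the last chunk may be
// shorter than chunkSize).
func chunksOf(size, chunkSize int64) []Chunk {
	var cs []Chunk
	for start := int64(0); start < size; start += chunkSize {
		cs = append(cs, Chunk{Start: start, End: min(start+chunkSize-1, size-1)})
	}
	return cs
}

func main() {
	for _, c := range chunksOf(10, 4) {
		fmt.Printf("Range: bytes=%d-%d (size=%d)\n", c.Start, c.End, c.Size())
	}
	// Prints:
	// Range: bytes=0-3 (size=4)
	// Range: bytes=4-7 (size=4)
	// Range: bytes=8-9 (size=2)
}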