Rework TestAllMiniLMEmbedTruncate into table-driven subtests that compare
truncated embeddings against a reference embedding, simplify the embed test
helpers, and make EmbedHandler reject input that is empty after truncation.

diff --git a/integration/embed_test.go b/integration/embed_test.go
index eb00f4ba..a6852448 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -8,6 +8,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/api"
 )
 
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
 	}
 
 	res, err := embeddingTestHelper(ctx, client, t, req)
-
 	if err != nil {
-		t.Fatalf("error: %v", err)
+		t.Fatal(err)
 	}
 
 	if len(res.Embedding) != 384 {
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
 	}
 
 	res, err := embedTestHelper(ctx, client, t, req)
-
 	if err != nil {
-		t.Fatalf("error: %v", err)
+		t.Fatal(err)
 	}
 
 	if len(res.Embeddings) != 1 {
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 	}
 
 	res, err := embedTestHelper(ctx, client, t, req)
-
 	if err != nil {
-		t.Fatalf("error: %v", err)
+		t.Fatal(err)
 	}
 
 	if len(res.Embeddings) != 2 {
@@ -156,93 +154,138 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 
 	truncTrue, truncFalse := true, false
 
-	type testReq struct {
-		Name    string
-		Request api.EmbedRequest
+	// want is the embedding of the short input "why"; the truncating cases below are compared against it.
+	want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
+		Model: "all-minilm",
+		Input: "why",
+	})
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	reqs := []testReq{
+	cases := []struct {
+		name    string
+		request api.EmbedRequest
+		// check receives the subtest's t: FailNow must run on the goroutine of the test it fails.
+		check   func(*testing.T, *api.EmbedResponse, error)
+	}{
 		{
-			Name: "Target Truncation",
-			Request: api.EmbedRequest{
+			name: "target truncation",
+			request: api.EmbedRequest{
 				Model: "all-minilm",
 				Input: "why",
 			},
-		},
-		{
-			Name: "Default Truncate",
-			Request: api.EmbedRequest{
-				Model:   "all-minilm",
-				Input:   "why is the sky blue?",
-				Options: map[string]any{"num_ctx": 1},
+			check: func(t *testing.T, got *api.EmbedResponse, err error) {
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+					t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+				}
 			},
 		},
 		{
-			Name: "Explicit Truncate",
-			Request: api.EmbedRequest{
+			name: "default truncate",
+			request: api.EmbedRequest{
+				Model:   "all-minilm",
+				Input:   "why is the sky blue?",
+				Options: map[string]any{"num_ctx": 3},
+			},
+			check: func(t *testing.T, got *api.EmbedResponse, err error) {
+				if err != nil {
+					t.Fatal(err)
+				}
+				if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+					t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+				}
+			},
+		},
+		{
+			name: "explicit truncate",
+			request: api.EmbedRequest{
+				Model:    "all-minilm",
+				Input:    "why is the sky blue?",
+				Truncate: &truncTrue,
+				Options:  map[string]any{"num_ctx": 3},
+			},
+			check: func(t *testing.T, got *api.EmbedResponse, err error) {
+				if err != nil {
+					t.Fatal(err)
+				}
+				if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+					t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+				}
+			},
+		},
+		{
+			name: "truncate error",
+			request: api.EmbedRequest{
+				Model:    "all-minilm",
+				Input:    "why is the sky blue?",
+				Truncate: &truncFalse,
+				Options:  map[string]any{"num_ctx": 3},
+			},
+			check: func(t *testing.T, _ *api.EmbedResponse, err error) {
+				if err == nil || err.Error() != "input exceeds maximum context length" {
+					t.Fatalf("expected truncation error, got: %v", err)
+				}
+			},
+		},
+		{
+			name: "input after truncate error",
+			request: api.EmbedRequest{
 				Model:    "all-minilm",
 				Input:    "why is the sky blue?",
 				Truncate: &truncTrue,
 				Options:  map[string]any{"num_ctx": 1},
 			},
+			check: func(t *testing.T, _ *api.EmbedResponse, err error) {
+				if err == nil || err.Error() != "input after truncation exceeds maximum context length" {
+					t.Fatalf("expected truncation error, got: %v", err)
+				}
+			},
+		},
+		{
+			name: "input after truncate error with zero context",
+			request: api.EmbedRequest{
+				Model:    "all-minilm",
+				Input:    "why is the sky blue?",
+				Truncate: &truncTrue,
+				Options:  map[string]any{"num_ctx": 0},
+			},
+			check: func(t *testing.T, _ *api.EmbedResponse, err error) {
+				if err == nil || err.Error() != "input after truncation exceeds maximum context length" {
+					t.Fatalf("expected truncation error, got: %v", err)
+				}
+			},
 		},
 	}
 
-	res := make(map[string]*api.EmbedResponse)
-
-	for _, req := range reqs {
-		response, err := embedTestHelper(ctx, client, t, req.Request)
-		if err != nil {
-			t.Fatalf("error: %v", err)
-		}
-		res[req.Name] = response
-	}
-
-	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
-		t.Fatal("expected default request to truncate correctly")
-	}
-
-	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
-		t.Fatal("expected default request and truncate true request to be the same")
-	}
-
-	// check that truncate set to false returns an error if context length is exceeded
-	_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
-		Model:    "all-minilm",
-		Input:    "why is the sky blue?",
-		Truncate: &truncFalse,
-		Options:  map[string]any{"num_ctx": 1},
-	})
-
-	if err == nil {
-		t.Fatal("expected error, got nil")
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := embedTestHelper(ctx, client, t, tc.request)
+			tc.check(t, got, err)
+		})
 	}
 }
 
 func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
+	t.Helper()
+
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
-		t.Fatalf("failed to pull model %s: %v", req.Model, err)
+		t.Fatal(err)
 	}
 
-	response, err := client.Embeddings(ctx, &req)
-
-	if err != nil {
-		return nil, err
-	}
-
-	return response, nil
+	return client.Embeddings(ctx, &req)
 }
 
 func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
+	t.Helper()
+
 	if err := PullIfMissing(ctx, client, req.Model); err != nil {
-		t.Fatalf("failed to pull model %s: %v", req.Model, err)
+		t.Fatal(err)
 	}
 
-	response, err := client.Embed(ctx, &req)
-
-	if err != nil {
-		return nil, err
-	}
-
-	return response, nil
+	return client.Embed(ctx, &req)
 }
diff --git a/server/routes.go b/server/routes.go
index b1def0de..c0204531 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -634,7 +634,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
 		if len(tokens) > ctxLen {
 			if !truncate {
-				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
 				return
 			}
 
@@ -646,6 +646,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 				ctxLen--
 			}
 
+			slog.Debug("truncating embed input", "ctxLen", ctxLen, "tokenCount", len(tokens))
+			if ctxLen <= 0 {
+				// return error if the truncated input would be empty or just special tokens
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
+				return
+			}
+
 			tokens = tokens[:ctxLen]
 			s, err = r.Detokenize(c.Request.Context(), tokens)