Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2024-07-03 10:58:24 +08:00
committed by GitHub
40 changed files with 1129 additions and 217 deletions

View File

@@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS

View File

@@ -345,6 +345,13 @@ type ProcessModelResponse struct {
SizeVRAM int64 `json:"size_vram"`
}
type RetrieveModelResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type TokenResponse struct {
Token string `json:"token"`
}

View File

@@ -266,8 +266,10 @@ If there is insufficient available memory to load a new model request while one
Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
The following server settings may be used to adjust how Ollama handles concurrent requests:
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

View File

@@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
}
]
}'
```
## Endpoints

View File

@@ -12,6 +12,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
type Error struct {
@@ -42,6 +43,12 @@ type ChunkChoice struct {
FinishReason *string `json:"finish_reason"`
}
type CompleteChunkChoice struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason *string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -85,6 +92,51 @@ type ChatCompletionChunk struct {
Choices []ChunkChoice `json:"choices"`
}
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens *int `json:"max_tokens"`
PresencePenalty float32 `json:"presence_penalty"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
}
type Completion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []CompleteChunkChoice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
}
type CompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []CompleteChunkChoice `json:"choices"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
}
type Model struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
@@ -145,7 +197,79 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
}
}
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
func toCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
Object: "text_completion",
Created: r.CreatedAt.Unix(),
Model: r.Model,
SystemFingerprint: "fp_ollama",
Choices: []CompleteChunkChoice{{
Text: r.Response,
Index: 0,
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
}
return nil
}(r.DoneReason),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
}
}
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
return CompletionChunk{
Id: id,
Object: "text_completion",
Created: time.Now().Unix(),
Model: r.Model,
SystemFingerprint: "fp_ollama",
Choices: []CompleteChunkChoice{{
Text: r.Response,
Index: 0,
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
}
return nil
}(r.DoneReason),
}},
}
}
func toListCompletion(r api.ListResponse) ListCompletion {
var data []Model
for _, m := range r.Models {
data = append(data, Model{
Id: m.Name,
Object: "model",
Created: m.ModifiedAt.Unix(),
OwnedBy: model.ParseName(m.Name).Namespace,
})
}
return ListCompletion{
Object: "list",
Data: data,
}
}
func toModel(r api.ShowResponse, m string) Model {
return Model{
Id: m,
Object: "model",
Created: r.ModifiedAt.Unix(),
OwnedBy: model.ParseName(m).Namespace,
}
}
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
var messages []api.Message
for _, msg := range r.Messages {
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
@@ -156,7 +280,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []interface{}:
case []any:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
@@ -208,13 +332,78 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
}
}
type writer struct {
stream bool
id string
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
options := make(map[string]any)
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []string:
options["stop"] = stop
default:
if r.Stop != nil {
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
}
}
if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens
}
if r.Temperature != nil {
options["temperature"] = *r.Temperature * 2.0
} else {
options["temperature"] = 1.0
}
if r.Seed != nil {
options["seed"] = *r.Seed
}
options["frequency_penalty"] = r.FrequencyPenalty * 2.0
options["presence_penalty"] = r.PresencePenalty * 2.0
if r.TopP != 0.0 {
options["top_p"] = r.TopP
} else {
options["top_p"] = 1.0
}
return api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
}, nil
}
type BaseWriter struct {
gin.ResponseWriter
}
func (w *writer) writeError(code int, data []byte) (int, error) {
type ChatWriter struct {
stream bool
id string
BaseWriter
}
type CompleteWriter struct {
stream bool
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
@@ -230,7 +419,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
return len(data), nil
}
func (w *writer) writeResponse(data []byte) (int, error) {
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
@@ -270,7 +459,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
return len(data), nil
}
func (w *writer) Write(data []byte) (int, error) {
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
@@ -279,7 +468,176 @@ func (w *writer) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
func Middleware() gin.HandlerFunc {
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
// response writer
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := fromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
@@ -294,17 +652,17 @@ func Middleware() gin.HandlerFunc {
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &writer{
ResponseWriter: c.Writer,
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
}
c.Writer = w

298
openai/openai_test.go Normal file
View File

@@ -0,0 +1,298 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
)
func TestMiddleware(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
}
testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
userMessage := chatReq.Messages[0].Content
var assistantMessage string
switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}
c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}
if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}
if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}
temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string
switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}
c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}
if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "Test Model",
},
},
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
if listResp.Object != "list" {
t.Fatalf("expected list, got %s", listResp.Object)
}
if len(listResp.Data) != 1 {
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
},
{
Name: "retrieve model",
Method: http.MethodGet,
Path: "/api/show/:model",
TestPath: "/api/show/test-model",
Handler: RetrieveMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var retrieveResp Model
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
t.Fatal(err)
}
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
},
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, tc.Endpoint)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, resp)
})
}
}

View File

@@ -6,10 +6,21 @@ set -ex
MACHINE=$(uname -m)
if grep -i "centos" /etc/system-release >/dev/null; then
# As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
# Centos 7 derivatives have too old of a git version to run our generate script
# uninstall and ignore failures
yum remove -y git
yum -y install epel-release centos-release-scl
# The release packages reinstate the mirrors, undo that again
sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
yum -y install dnf
if [ "${MACHINE}" = "x86_64" ]; then
yum -y install https://repo.ius.io/ius-release-el7.rpm

View File

@@ -28,11 +28,16 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
type Capability string
const CapabilityCompletion = Capability("completion")
type registryOptions struct {
Insecure bool
Username string
@@ -48,16 +53,43 @@ type Model struct {
ParentModel string
AdapterPaths []string
ProjectorPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
Messages []Message
Template *template.Template
}
func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
func (m *Model) Has(caps ...Capability) bool {
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
f, err := os.Open(m.ModelPath)
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
defer f.Close()
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
ggml, _, err := llm.DecodeGGML(f, 0)
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
return false
}
default:
slog.Error("unknown capability", "capability", cap)
return false
}
}
return true
}
func (m *Model) String() string {
@@ -82,10 +114,10 @@ func (m *Model) String() string {
})
}
if m.Template != "" {
if m.Template != nil {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "template",
Args: m.Template,
Args: m.Template.String(),
})
}
@@ -135,13 +167,6 @@ type Message struct {
Content string `json:"content"`
}
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
@@ -160,7 +185,7 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"`
}
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
@@ -170,7 +195,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
return nil, "", err
}
var manifest *ManifestV2
var manifest *Manifest
bts, err := os.ReadFile(fp)
if err != nil {
@@ -198,8 +223,7 @@ func GetModel(name string) (*Model, error) {
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Template: template.DefaultTemplate,
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -235,13 +259,17 @@ func GetModel(name string) (*Model, error) {
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template":
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename)
if err != nil {
@@ -249,13 +277,6 @@ func GetModel(name string) (*Model, error) {
}
model.System = string(bts)
case "application/vnd.ollama.image.prompt":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -822,7 +843,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
var manifest *Manifest
var err error
var noprune string
@@ -929,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
@@ -940,7 +961,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
var m *ManifestV2
var m *Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}

View File

@@ -14,7 +14,10 @@ import (
)
type Manifest struct {
ManifestV2
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
filepath string
fi os.FileInfo
@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
p := filepath.Join(manifests, n.Filepath())
var m ManifestV2
var m Manifest
f, err := os.Open(p)
if err != nil {
return nil, err
@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
}
return &Manifest{
ManifestV2: m,
filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
m.filepath = p
m.fi = fi
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
return &m, nil
}
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
}
defer f.Close()
m := ManifestV2{
m := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,

View File

@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
t.Fatal(err)
}
}

View File

@@ -11,12 +11,11 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/templates"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -91,12 +90,11 @@ func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse))
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
n := filepath.Join(p, f.Name)
if !strings.HasPrefix(n, p) {
slog.Warn("skipped extracting file outside of context", "name", f.Name)
continue
if !filepath.IsLocal(f.Name) {
return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
}
n := filepath.Join(p, f.Name)
if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
return err
}
@@ -258,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := templates.NamedTemplate(s); err != nil {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err)
} else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")

View File

@@ -3,10 +3,12 @@ package server
import (
"archive/zip"
"bytes"
"errors"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/api"
@@ -39,13 +41,31 @@ func TestExtractFromZipFile(t *testing.T) {
cases := []struct {
name string
expect []string
err error
}{
{
name: "good",
expect: []string{"good"},
},
{
name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"),
name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
expect: []string{filepath.Join("to", "good")},
},
{
name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
{
name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
}
@@ -55,7 +75,7 @@ func TestExtractFromZipFile(t *testing.T) {
defer f.Close()
tempDir := t.TempDir()
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil {
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
t.Fatal(err)
}

View File

@@ -4,10 +4,11 @@ import (
"fmt"
"log/slog"
"strings"
"text/template"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
// isResponseNode checks if the node contains .Response
@@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
formatTemplateForResponse(tmpl, generate)
vars := map[string]any{
"System": system,
@@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
}
var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil {
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
@@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func TestPrompt(t *testing.T) {
@@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
@@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}

View File

@@ -31,6 +31,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -121,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
}
@@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
tmpl, err := template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now()
var prompt string
@@ -169,7 +176,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
req.Template = model.Template
tmpl = model.Template
}
if req.System == "" {
@@ -187,7 +194,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
sb.WriteString(req.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -242,7 +249,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -680,7 +687,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
if req.Template != "" {
m.Template = req.Template
m.Template, err = template.Parse(req.Template)
if err != nil {
return nil, err
}
}
msgs := make([]api.Message, 0)
@@ -701,7 +711,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
ModifiedAt: manifest.fi.ModTime(),
@@ -1039,7 +1049,10 @@ func (s *Server) GenerateRoutes() http.Handler {
r.GET("/api/ps", s.ProcessHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
@@ -1246,7 +1259,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
@@ -1294,8 +1307,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -105,6 +106,24 @@ func Test_Routes(t *testing.T) {
assert.Empty(t, len(modelList.Models))
},
},
{
Name: "openai empty list",
Method: http.MethodGet,
Path: "/v1/models",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var modelList openai.ListCompletion
err = json.Unmarshal(body, &modelList)
require.NoError(t, err)
assert.Equal(t, "list", modelList.Object)
assert.Empty(t, modelList.Data)
},
},
{
Name: "Tags Handler (yes tags)",
Method: http.MethodGet,
@@ -128,6 +147,25 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
},
},
{
Name: "openai list models with tags",
Method: http.MethodGet,
Path: "/v1/models",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var modelList openai.ListCompletion
err = json.Unmarshal(body, &modelList)
require.NoError(t, err)
assert.Len(t, modelList.Data, 1)
assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
assert.Equal(t, "library", modelList.Data[0].OwnedBy)
},
},
{
Name: "Create Model Handler",
Method: http.MethodPost,
@@ -216,6 +254,24 @@ func Test_Routes(t *testing.T) {
assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
},
},
{
Name: "openai retrieve model handler",
Method: http.MethodGet,
Path: "/v1/models/show-model",
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var retrieveResp api.RetrieveModelResponse
err = json.Unmarshal(body, &retrieveResp)
require.NoError(t, err)
assert.Equal(t, "show-model", retrieveResp.Id)
assert.Equal(t, "library", retrieveResp.OwnedBy)
},
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())

158
template/template.go Normal file
View File

@@ -0,0 +1,158 @@
package template
import (
"bytes"
"embed"
"encoding/json"
"errors"
"io"
"math"
"slices"
"strings"
"sync"
"text/template"
"text/template/parse"
"github.com/agnivade/levenshtein"
"golang.org/x/exp/maps"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
var templates []*named
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
for _, t := range templates {
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
if err != nil {
return nil, err
}
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
}
return templates, nil
})
type named struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
}
func (t named) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
func Named(s string) (*named, error) {
templates, err := templatesOnce()
if err != nil {
return nil, err
}
var template *named
score := math.MaxInt
for _, t := range templates {
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
score = s
template = t
}
}
if score < 100 {
return template, nil
}
return nil, errors.New("no matching template found")
}
type Template struct {
*template.Template
raw string
}
func (t *Template) String() string {
return t.raw
}
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
func Parse(s string) (*Template, error) {
t, err := template.New("").Option("missingkey=zero").Parse(s)
if err != nil {
return nil, err
}
return &Template{Template: t, raw: s}, nil
}
func (t *Template) Vars() []string {
var vars []string
for _, n := range t.Tree.Root.Nodes {
vars = append(vars, parseNode(n)...)
}
set := make(map[string]struct{})
for _, n := range vars {
set[strings.ToLower(n)] = struct{}{}
}
vars = maps.Keys(set)
slices.Sort(vars)
return vars
}
func parseNode(n parse.Node) []string {
switch n := n.(type) {
case *parse.ActionNode:
return parseNode(n.Pipe)
case *parse.IfNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.RangeNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.WithNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
case *parse.PipeNode:
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, parseNode(a)...)
}
}
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names
case *parse.FieldNode:
return n.Ident
}
return nil
}

89
template/template_test.go Normal file
View File

@@ -0,0 +1,89 @@
package template
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"slices"
"testing"
"text/template"
"github.com/ollama/ollama/llm"
)
func TestNamed(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
var ss map[string]string
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
t.Fatal(err)
}
for k, v := range ss {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := Named(s)
if err != nil {
t.Fatal(err)
}
if r.Name != k {
t.Errorf("expected %q, got %q", k, r.Name)
}
var b bytes.Buffer
if _, err := io.Copy(&b, r.Reader()); err != nil {
t.Fatal(err)
}
tmpl, err := template.New(s).Parse(b.String())
if err != nil {
t.Fatal(err)
}
if tmpl.Tree.Root.String() == "" {
t.Errorf("empty %s template", k)
}
})
}
}
}
func TestParse(t *testing.T) {
cases := []struct {
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
tmpl, err := Parse(tt.template)
if err != nil {
t.Fatal(err)
}
vars := tmpl.Vars()
if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.vars, vars)
}
})
}
}

View File

@@ -1,70 +0,0 @@
package templates
import (
"bytes"
"embed"
"encoding/json"
"errors"
"io"
"math"
"sync"
"github.com/agnivade/levenshtein"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
var templates []*Template
if err := json.Unmarshal(indexBytes, &templates); err != nil {
return nil, err
}
for _, t := range templates {
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
if err != nil {
return nil, err
}
// normalize line endings
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
}
return templates, nil
})
type Template struct {
Name string `json:"name"`
Template string `json:"template"`
Bytes []byte
}
func (t Template) Reader() io.Reader {
return bytes.NewReader(t.Bytes)
}
func NamedTemplate(s string) (*Template, error) {
templates, err := templatesOnce()
if err != nil {
return nil, err
}
var template *Template
score := math.MaxInt
for _, t := range templates {
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
score = s
template = t
}
}
if score < 100 {
return template, nil
}
return nil, errors.New("no matching template found")
}

View File

@@ -1,59 +0,0 @@
package templates
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"testing"
"text/template"
"github.com/ollama/ollama/llm"
)
func TestKVChatTemplate(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
var ss map[string]string
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
t.Fatal(err)
}
for k, v := range ss {
t.Run(k, func(t *testing.T) {
kv := llm.KV{"tokenizer.chat_template": v}
s := kv.ChatTemplate()
r, err := NamedTemplate(s)
if err != nil {
t.Fatal(err)
}
if r.Name != k {
t.Errorf("expected %q, got %q", k, r.Name)
}
var b bytes.Buffer
if _, err := io.Copy(&b, r.Reader()); err != nil {
t.Fatal(err)
}
tmpl, err := template.New(s).Parse(b.String())
if err != nil {
t.Fatal(err)
}
if tmpl.Tree.Root.String() == "" {
t.Errorf("empty %s template", k)
}
})
}
}
}