Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-21 22:33:56 +00:00)

Commit: Merge branch 'ollama:main' into main

Dockerfile (36 changed lines)
@@ -1,6 +1,7 @@
 # vim: filetype=dockerfile
 
 ARG FLAVOR=${TARGETARCH}
+ARG PARALLEL=8
 
 ARG ROCMVERSION=6.3.3
 ARG JETPACK5VERSION=r35.4.1
@@ -34,46 +35,51 @@ ENV LDFLAGS=-s
 FROM base AS cpu
 RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
 ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'CPU' \
-        && cmake --build --parallel --preset 'CPU' \
-        && cmake --install build --component CPU --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'CPU' \
+        && cmake --install build --component CPU --strip --parallel ${PARALLEL}
 
 FROM base AS cuda-11
 ARG CUDA11VERSION=11.8
 RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
 ENV PATH=/usr/local/cuda-11/bin:$PATH
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
-        && cmake --build --parallel --preset 'CUDA 11' \
-        && cmake --install build --component CUDA --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
+        && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
 
 FROM base AS cuda-12
 ARG CUDA12VERSION=12.8
 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
 ENV PATH=/usr/local/cuda-12/bin:$PATH
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
-        && cmake --build --parallel --preset 'CUDA 12' \
-        && cmake --install build --component CUDA --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
+        && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
 
 
 FROM base AS cuda-13
 ARG CUDA13VERSION=13.0
 RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
 ENV PATH=/usr/local/cuda-13/bin:$PATH
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
    cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
-        && cmake --build --parallel --preset 'CUDA 13' \
-        && cmake --install build --component CUDA --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
+        && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
 
 
 FROM base AS rocm-6
 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'ROCm 6' \
-        && cmake --build --parallel --preset 'ROCm 6' \
-        && cmake --install build --component HIP --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
+        && cmake --install build --component HIP --strip --parallel ${PARALLEL}
 
 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
 ARG CMAKEVERSION
@@ -81,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \
     && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
 COPY CMakeLists.txt CMakePresets.json .
 COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'JetPack 5' \
-        && cmake --build --parallel --preset 'JetPack 5' \
-        && cmake --install build --component CUDA --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
+        && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
 
 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
 ARG CMAKEVERSION
@@ -92,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
     && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
 COPY CMakeLists.txt CMakePresets.json .
 COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
+ARG PARALLEL
 RUN --mount=type=cache,target=/root/.ccache \
     cmake --preset 'JetPack 6' \
-        && cmake --build --parallel --preset 'JetPack 6' \
-        && cmake --install build --component CUDA --strip --parallel 8
+        && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
+        && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
 
 FROM base AS build
 WORKDIR /go/src/github.com/ollama/ollama

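Every build stage now takes its job count from the new PARALLEL build argument (default 8) for both `cmake --build` and `cmake --install`, so the build concurrency can presumably be tuned at build time with something along the lines of `docker build --build-arg PARALLEL=4 .` instead of editing the Dockerfile.
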
api/client.go
@@ -222,7 +222,17 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
             return fmt.Errorf("unmarshal: %w", err)
         }
 
-        if response.StatusCode >= http.StatusBadRequest {
+        if response.StatusCode == http.StatusUnauthorized {
+            pubKey, pkErr := auth.GetPublicKey()
+            if pkErr != nil {
+                return pkErr
+            }
+            return AuthorizationError{
+                StatusCode: response.StatusCode,
+                Status:     response.Status,
+                PublicKey:  pubKey,
+            }
+        } else if response.StatusCode >= http.StatusBadRequest {
             return StatusError{
                 StatusCode: response.StatusCode,
                 Status:     response.Status,
@@ -428,3 +438,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
 
     return version.Version, nil
 }
+
+// Signout will disconnect an ollama instance from ollama.com
+func (c *Client) Signout(ctx context.Context, encodedKey string) error {
+    return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
+}
+
+func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
+    var resp UserResponse
+    if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}

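For orientation, a minimal usage sketch (not part of the commit; the flow and error handling are illustrative assumptions) of how a Go program might call the new Whoami and Signout client methods, encoding the key the same way the CLI handlers further down do:

```go
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/auth"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Whoami reports the ollama.com account this instance is connected to, if any.
	user, err := client.Whoami(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("signed in as:", user.Name)

	// Signout revokes this machine's key; the key is sent base64url-encoded,
	// matching the encoding the CLI's signout handler uses.
	pubKey, err := auth.GetPublicKey()
	if err != nil {
		log.Fatal(err)
	}
	encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
	if err := client.Signout(context.Background(), encKey); err != nil {
		log.Fatal(err)
	}
}
```
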
api/types.go (133 changed lines)
@@ -11,6 +11,8 @@ import (
     "strings"
     "time"
 
+    "github.com/google/uuid"
+
     "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/types/model"
 )
@@ -36,6 +38,19 @@ func (e StatusError) Error() string {
     }
 }
 
+type AuthorizationError struct {
+    StatusCode int
+    Status     string
+    PublicKey  string `json:"public_key"`
+}
+
+func (e AuthorizationError) Error() string {
+    if e.Status != "" {
+        return e.Status
+    }
+    return "something went wrong, please see the ollama server logs for details"
+}
+
 // ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
@@ -313,13 +328,29 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-    Model      string    `json:"model"`
-    CreatedAt  time.Time `json:"created_at"`
-    Message    Message   `json:"message"`
-    DoneReason string    `json:"done_reason,omitempty"`
+    // Model is the model name that generated the response.
+    Model string `json:"model"`
+
+    // RemoteModel is the name of the upstream model that generated the response.
+    RemoteModel string `json:"remote_model,omitempty"`
+
+    // RemoteHost is the URL of the upstream Ollama host that generated the response.
+    RemoteHost string `json:"remote_host,omitempty"`
+
+    // CreatedAt is the timestamp of the response.
+    CreatedAt time.Time `json:"created_at"`
+
+    // Message contains the message or part of a message from the model.
+    Message Message `json:"message"`
+
+    // Done specifies if the response is complete.
     Done bool `json:"done"`
 
+    // DoneReason is the reason the model stopped generating text.
+    DoneReason string `json:"done_reason,omitempty"`
+
+    DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
+
     Metrics
 }
 
@@ -329,13 +360,6 @@ type DebugInfo struct {
     ImageCount int `json:"image_count,omitempty"`
 }
 
-// DebugTemplateResponse is returned when _debug_render_only is set to true
-type DebugTemplateResponse struct {
-    Model     string    `json:"model"`
-    CreatedAt time.Time `json:"created_at"`
-    DebugInfo DebugInfo `json:"_debug_info"`
-}
-
 type Metrics struct {
     TotalDuration time.Duration `json:"total_duration,omitempty"`
     LoadDuration  time.Duration `json:"load_duration,omitempty"`
@@ -431,18 +455,47 @@ type EmbeddingResponse struct {
 
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
-    Model    string `json:"model"`
-    Stream   *bool  `json:"stream,omitempty"`
+    // Model is the model name to create.
+    Model string `json:"model"`
+
+    // Stream specifies whether the response is streaming; it is true by default.
+    Stream *bool `json:"stream,omitempty"`
+
+    // Quantize is the quantization format for the model; leave blank to not change the quantization level.
     Quantize string `json:"quantize,omitempty"`
 
-    From       string            `json:"from,omitempty"`
-    Files      map[string]string `json:"files,omitempty"`
-    Adapters   map[string]string `json:"adapters,omitempty"`
-    Template   string            `json:"template,omitempty"`
-    License    any               `json:"license,omitempty"`
-    System     string            `json:"system,omitempty"`
-    Parameters map[string]any    `json:"parameters,omitempty"`
-    Messages   []Message         `json:"messages,omitempty"`
+    // From is the name of the model or file to use as the source.
+    From string `json:"from,omitempty"`
+
+    // RemoteHost is the URL of the upstream ollama API for the model (if any).
+    RemoteHost string `json:"remote_host,omitempty"`
+
+    // Files is a map of files include when creating the model.
+    Files map[string]string `json:"files,omitempty"`
+
+    // Adapters is a map of LoRA adapters to include when creating the model.
+    Adapters map[string]string `json:"adapters,omitempty"`
+
+    // Template is the template used when constructing a request to the model.
+    Template string `json:"template,omitempty"`
+
+    // License is a string or list of strings for licenses.
+    License any `json:"license,omitempty"`
+
+    // System is the system prompt for the model.
+    System string `json:"system,omitempty"`
+
+    // Parameters is a map of hyper-parameters which are applied to the model.
+    Parameters map[string]any `json:"parameters,omitempty"`
+
+    // Messages is a list of messages added to the model before chat and generation requests.
+    Messages []Message `json:"messages,omitempty"`
+
+    Renderer string `json:"renderer,omitempty"`
+    Parser   string `json:"parser,omitempty"`
+
+    // Info is a map of additional information for the model
+    Info map[string]any `json:"info,omitempty"`
 
     // Deprecated: set the model name with Model instead
     Name string `json:"name"`
@@ -480,8 +533,12 @@ type ShowResponse struct {
     Parameters    string         `json:"parameters,omitempty"`
     Template      string         `json:"template,omitempty"`
     System        string         `json:"system,omitempty"`
+    Renderer      string         `json:"renderer,omitempty"`
+    Parser        string         `json:"parser,omitempty"`
     Details       ModelDetails   `json:"details,omitempty"`
     Messages      []Message      `json:"messages,omitempty"`
+    RemoteModel   string         `json:"remote_model,omitempty"`
+    RemoteHost    string         `json:"remote_host,omitempty"`
     ModelInfo     map[string]any `json:"model_info,omitempty"`
     ProjectorInfo map[string]any `json:"projector_info,omitempty"`
     Tensors       []Tensor       `json:"tensors,omitempty"`
@@ -540,12 +597,14 @@ type ProcessResponse struct {
 
 // ListModelResponse is a single model description in [ListResponse].
 type ListModelResponse struct {
     Name        string       `json:"name"`
     Model       string       `json:"model"`
+    RemoteModel string       `json:"remote_model,omitempty"`
+    RemoteHost  string       `json:"remote_host,omitempty"`
     ModifiedAt  time.Time    `json:"modified_at"`
     Size        int64        `json:"size"`
     Digest      string       `json:"digest"`
     Details     ModelDetails `json:"details,omitempty"`
 }
 
 // ProcessModelResponse is a single model description in [ProcessResponse].
@@ -569,6 +628,12 @@ type GenerateResponse struct {
     // Model is the model name that generated the response.
     Model string `json:"model"`
 
+    // RemoteModel is the name of the upstream model that generated the response.
+    RemoteModel string `json:"remote_model,omitempty"`
+
+    // RemoteHost is the URL of the upstream Ollama host that generated the response.
+    RemoteHost string `json:"remote_host,omitempty"`
+
     // CreatedAt is the timestamp of the response.
     CreatedAt time.Time `json:"created_at"`
 
@@ -592,6 +657,8 @@ type GenerateResponse struct {
     Metrics
 
     ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
+    DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
 }
 
 // ModelDetails provides details about a model.
@@ -604,6 +671,18 @@ type ModelDetails struct {
     QuantizationLevel string `json:"quantization_level"`
 }
 
+// UserResponse provides information about a user.
+type UserResponse struct {
+    ID        uuid.UUID `json:"id"`
+    Email     string    `json:"email"`
+    Name      string    `json:"name"`
+    Bio       string    `json:"bio,omitempty"`
+    AvatarURL string    `json:"avatarurl,omitempty"`
+    FirstName string    `json:"firstname,omitempty"`
+    LastName  string    `json:"lastname,omitempty"`
+    Plan      string    `json:"plan,omitempty"`
+}
+
 // Tensor describes the metadata for a given tensor.
 type Tensor struct {
     Name string `json:"name"`

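A small illustrative sketch (hypothetical caller code, not from the commit) of how a client program might detect the new AuthorizationError that the API client now returns for 401 responses on streaming calls:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

// reportAuthError shows one way a caller might react to AuthorizationError:
// on 401 the error carries the local public key, which the CLI turns into
// an ollama.com/connect sign-in link.
func reportAuthError(ctx context.Context, client *api.Client, model string) error {
	req := &api.GenerateRequest{Model: model}
	err := client.Generate(ctx, req, func(api.GenerateResponse) error { return nil })

	var authErr api.AuthorizationError
	if errors.As(err, &authErr) && authErr.StatusCode == http.StatusUnauthorized {
		fmt.Println("not signed in; local public key:", authErr.PublicKey)
		return nil
	}
	return err
}
```
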
auth/auth.go (25 changed lines)
@@ -19,6 +19,31 @@ import (
 const defaultPrivateKey = "id_ed25519"
 
 func keyPath() (string, error) {
+    fileIsReadable := func(fp string) bool {
+        info, err := os.Stat(fp)
+        if err != nil {
+            return false
+        }
+
+        // Check that it's a regular file, not a directory or other file type
+        if !info.Mode().IsRegular() {
+            return false
+        }
+
+        // Try to open it to check readability
+        file, err := os.Open(fp)
+        if err != nil {
+            return false
+        }
+        file.Close()
+        return true
+    }
+
+    systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
+    if fileIsReadable(systemPath) {
+        return systemPath, nil
+    }
+
     home, err := os.UserHomeDir()
     if err != nil {
         return "", err

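In effect, key resolution now prefers a shared system key: if /usr/share/ollama/.ollama/id_ed25519 exists, is a regular file, and can be opened for reading, it is used; otherwise the lookup falls back to the key under the current user's home directory as before.
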
cmd/cmd.go (155 changed lines)
@@ -5,6 +5,7 @@ import (
     "context"
     "crypto/ed25519"
     "crypto/rand"
+    "encoding/base64"
     "encoding/json"
     "encoding/pem"
     "errors"
@@ -14,6 +15,7 @@ import (
     "math"
     "net"
     "net/http"
+    "net/url"
     "os"
     "os/signal"
     "path/filepath"
@@ -35,6 +37,7 @@ import (
     "golang.org/x/term"
 
     "github.com/ollama/ollama/api"
+    "github.com/ollama/ollama/auth"
     "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/format"
     "github.com/ollama/ollama/parser"
@@ -47,6 +50,8 @@ import (
     "github.com/ollama/ollama/version"
 )
 
+const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
+
 // ensureThinkingSupport emits a warning if the model does not advertise thinking support
 func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
     if name == "" {
@@ -286,7 +291,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
         Think:     opts.Think,
     }
 
-    return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
+    return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
+        if r.RemoteModel != "" && opts.ShowConnect {
+            p.StopAndClear()
+            if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
+                fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
+            } else {
+                fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
+            }
+        }
+        return nil
+    })
 }
 
 func StopHandler(cmd *cobra.Command, args []string) error {
@@ -307,9 +322,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
     interactive := true
 
     opts := runOptions{
         Model:    args[0],
         WordWrap: os.Getenv("TERM") == "xterm-256color",
         Options:  map[string]any{},
+        ShowConnect: true,
     }
 
     format, err := cmd.Flags().GetString("format")
@@ -367,6 +383,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
         }
 
         prompts = append([]string{string(in)}, prompts...)
+        opts.ShowConnect = false
         opts.WordWrap = false
         interactive = false
     }
@@ -433,6 +450,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 
     if interactive {
         if err := loadOrUnloadModel(cmd, &opts); err != nil {
+            var sErr api.AuthorizationError
+            if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+                pubKey, pkErr := auth.GetPublicKey()
+                if pkErr != nil {
+                    return pkErr
+                }
+                // the server and the client both have the same public key
+                if pubKey == sErr.PublicKey {
+                    h, _ := os.Hostname()
+                    encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
+                    fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
+                    fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
+                }
+                return nil
+            }
             return err
         }
 
@@ -453,6 +485,56 @@ func RunHandler(cmd *cobra.Command, args []string) error {
     return generate(cmd, opts)
 }
 
+func SigninHandler(cmd *cobra.Command, args []string) error {
+    client, err := api.ClientFromEnvironment()
+    if err != nil {
+        return err
+    }
+
+    user, err := client.Whoami(cmd.Context())
+    if err != nil {
+        return err
+    }
+
+    if user != nil && user.Name != "" {
+        fmt.Printf("You are already signed in as user '%s'\n", user.Name)
+        fmt.Println()
+        return nil
+    }
+
+    pubKey, pkErr := auth.GetPublicKey()
+    if pkErr != nil {
+        return pkErr
+    }
+    encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
+
+    h, _ := os.Hostname()
+    fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
+
+    return nil
+}
+
+func SignoutHandler(cmd *cobra.Command, args []string) error {
+    pubKey, pkErr := auth.GetPublicKey()
+    if pkErr != nil {
+        return pkErr
+    }
+    encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
+
+    client, err := api.ClientFromEnvironment()
+    if err != nil {
+        return err
+    }
+
+    err = client.Signout(cmd.Context(), encKey)
+    if err != nil {
+        return err
+    }
+    fmt.Println("You have signed out of ollama.com")
+    fmt.Println()
+
+    return nil
+}
+
 func PushHandler(cmd *cobra.Command, args []string) error {
     client, err := api.ClientFromEnvironment()
     if err != nil {
@@ -505,7 +587,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
         if spinner != nil {
             spinner.Stop()
         }
-        if strings.Contains(err.Error(), "access denied") {
+        errStr := strings.ToLower(err.Error())
+        if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
             return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
         }
         return err
@@ -539,7 +622,14 @@ func ListHandler(cmd *cobra.Command, args []string) error {
 
     for _, m := range models.Models {
         if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
-            data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
+            var size string
+            if m.RemoteModel != "" {
+                size = "-"
+            } else {
+                size = format.HumanBytes(m.Size)
+            }
+
+            data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
         }
     }
 
@@ -624,8 +714,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
         KeepAlive: &api.Duration{Duration: 0},
     }
     if err := loadOrUnloadModel(cmd, opts); err != nil {
-        if !strings.Contains(err.Error(), "not found") {
-            return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
+        if !strings.Contains(strings.ToLower(err.Error()), "not found") {
+            fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
         }
     }
 
@@ -736,12 +826,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
     }
 
     tableRender("Model", func() (rows [][]string) {
+        if resp.RemoteHost != "" {
+            rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
+            rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
+        }
+
         if resp.ModelInfo != nil {
             arch := resp.ModelInfo["general.architecture"].(string)
             rows = append(rows, []string{"", "architecture", arch})
-            rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
-            rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
-            rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
+
+            var paramStr string
+            if resp.Details.ParameterSize != "" {
+                paramStr = resp.Details.ParameterSize
+            } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
+                if f, ok := v.(float64); ok {
+                    paramStr = format.HumanNumber(uint64(f))
+                }
+            }
+            rows = append(rows, []string{"", "parameters", paramStr})
+
+            if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
+                if f, ok := v.(float64); ok {
+                    rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
+                }
+            }
+
+            if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
+                if f, ok := v.(float64); ok {
+                    rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
+                }
+            }
         } else {
             rows = append(rows, []string{"", "architecture", resp.Details.Family})
             rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
@@ -989,6 +1103,7 @@ type runOptions struct {
     KeepAlive    *api.Duration
     Think        *api.ThinkValue
     HideThinking bool
+    ShowConnect  bool
 }
 
 type displayResponseState struct {
@@ -1544,6 +1659,22 @@ func NewCLI() *cobra.Command {
 
     pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 
+    signinCmd := &cobra.Command{
+        Use:     "signin",
+        Short:   "Sign in to ollama.com",
+        Args:    cobra.ExactArgs(0),
+        PreRunE: checkServerHeartbeat,
+        RunE:    SigninHandler,
+    }
+
+    signoutCmd := &cobra.Command{
+        Use:     "signout",
+        Short:   "Sign out from ollama.com",
+        Args:    cobra.ExactArgs(0),
+        PreRunE: checkServerHeartbeat,
+        RunE:    SignoutHandler,
+    }
+
     listCmd := &cobra.Command{
         Use:     "list",
         Aliases: []string{"ls"},
@@ -1638,6 +1769,8 @@ func NewCLI() *cobra.Command {
         stopCmd,
         pullCmd,
         pushCmd,
+        signinCmd,
+        signoutCmd,
         listCmd,
         psCmd,
         copyCmd,

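Taken together, these hunks add two subcommands, `ollama signin` and `ollama signout`: signin prints the ollama.com/connect URL built from the escaped host name and the base64url-encoded public key (or reports the already-signed-in user), and signout revokes that key through the client's Signout method. The run path reuses the same AuthorizationError flow to prompt for sign-in when a cloud model is requested while signed out.
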
cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
 import (
     "bytes"
     "encoding/json"
+    "fmt"
     "io"
     "net/http"
     "net/http/httptest"
@@ -304,6 +305,8 @@ func TestDeleteHandler(t *testing.T) {
             w.WriteHeader(http.StatusOK)
         } else {
             w.WriteHeader(http.StatusNotFound)
+            errPayload := `{"error":"model '%s' not found"}`
+            w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
         }
         return
     }
@@ -346,7 +349,7 @@ func TestDeleteHandler(t *testing.T) {
     }
 
     err := DeleteHandler(cmd, []string{"test-model-not-found"})
-    if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
+    if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
         t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
     }
 }
@@ -499,7 +502,7 @@ func TestPushHandler(t *testing.T) {
                 w.Header().Set("Content-Type", "application/json")
                 w.WriteHeader(http.StatusUnauthorized)
                 err := json.NewEncoder(w).Encode(map[string]string{
-                    "error": "access denied",
+                    "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
                 })
                 if err != nil {
                     t.Fatal(err)
@@ -522,6 +525,7 @@ func TestPushHandler(t *testing.T) {
             defer mockServer.Close()
 
             t.Setenv("OLLAMA_HOST", mockServer.URL)
+            initializeKeypair()
 
             cmd := &cobra.Command{}
             cmd.Flags().Bool("insecure", false, "")

@@ -28,6 +28,7 @@ type bertModel struct {
     LayerNormEPS     float32 `json:"layer_norm_eps"`
     LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
     NormEpsilon      float32 `json:"norm_epsilon"`
+    normalizeEmbeddings bool
 
     PoolingType uint32
 }
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
 
     var pooling string
     for _, m := range modules {
-        if m.Type == "sentence_transformers.models.Pooling" {
+        switch m.Type {
+        case "sentence_transformers.models.Pooling":
             pooling = m.Path
-            break
+        case "sentence_transformers.models.Normalize":
+            p.normalizeEmbeddings = true
         }
     }
 
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
     kv["general.architecture"] = "bert"
     kv["bert.attention.causal"] = false
     kv["bert.pooling_type"] = p.PoolingType
+    kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
 
     kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
 

@@ -96,7 +96,7 @@ type safetensor struct {
 
 func (st safetensor) Kind() uint32 {
     kind := st.tensorBase.Kind()
-    if st.dtype == "BF16" && kind != tensorKindFP32 {
+    if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
         kind = tensorKindBF16
     }
 

@@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
         })
     }
 }
+
+func TestSafetensorKind(t *testing.T) {
+    tests := []struct {
+        name     string
+        st       safetensor
+        expected uint32
+    }{
+        {
+            name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
+            st: safetensor{
+                tensorBase: &tensorBase{
+                    name:  "weight.matrix",
+                    shape: []uint64{10, 10}, // will default to FP16
+                },
+                dtype: "BF16",
+            },
+            expected: tensorKindBF16,
+        },
+        {
+            name: "BF16 dtype with v. prefix should return base kind",
+            st: safetensor{
+                tensorBase: &tensorBase{
+                    name:  "v.weight.matrix",
+                    shape: []uint64{10, 10}, // will default to FP16
+                },
+                dtype: "BF16",
+            },
+            expected: tensorKindFP16,
+        },
+        {
+            name: "BF16 dtype with FP32 base kind should return FP32",
+            st: safetensor{
+                tensorBase: &tensorBase{
+                    name:  "weight.matrix",
+                    shape: []uint64{10}, // will default to FP32
+                },
+                dtype: "BF16",
+            },
+            expected: tensorKindFP32,
+        },
+        {
+            name: "Non-BF16 dtype should return base kind",
+            st: safetensor{
+                tensorBase: &tensorBase{
+                    name:  "weight.matrix",
+                    shape: []uint64{10, 10}, // will default to FP16
+                },
+                dtype: "FP16",
+            },
+            expected: tensorKindFP16,
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := tt.st.Kind()
+            if result != tt.expected {
+                t.Errorf("Kind() = %d, expected %d", result, tt.expected)
+            }
+        })
+    }
+}

@@ -16,7 +16,7 @@ import (
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
-func cudaVariant(gpuInfo CudaGPUInfo) string {
+func cudaVariant(gpuInfos []CudaGPUInfo) string {
     if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
         if CudaTegra != "" {
             ver := strings.Split(CudaTegra, ".")
@@ -45,12 +45,19 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
         }
     }
 
-    if gpuInfo.DriverMajor < 13 {
+    // Check GPU compute capability FIRST, lowest common denominator if multi-gpu
+    for _, gpuInfo := range gpuInfos {
+        if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
+            // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
+            return "v12"
+        }
+    }
+
+    // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
+    if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
         // The detected driver is older than 580 (Aug 2025)
         // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
-        if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
-            slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
-        }
+        slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
         return "v12"
     }
     return "v13"

@@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList {
             gpuInfo.MinimumMemory = cudaMinimumMemory
             gpuInfo.DriverMajor = driverMajor
             gpuInfo.DriverMinor = driverMinor
-            variant := cudaVariant(gpuInfo)
-
-            // Start with our bundled libraries
-            if variant != "" {
-                variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
-                if _, err := os.Stat(variantPath); err == nil {
-                    // Put the variant directory first in the search path to avoid runtime linking to the wrong library
-                    gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
-                }
-            }
             gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
-            gpuInfo.Variant = variant
 
             if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
                 unsupportedGPUs = append(unsupportedGPUs,
@@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList {
             // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
             cudaGPUs = append(cudaGPUs, gpuInfo)
         }
+        // Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
+        variant := cudaVariant(cudaGPUs)
+        var variantPath string
+        // Start with our bundled libraries
+        if variant != "" {
+            variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
+            if _, err := os.Stat(variantPath); err != nil {
+                variantPath = ""
+            }
+        }
+
+        for i := range cudaGPUs {
+            cudaGPUs[i].Variant = variant
+            if variantPath != "" {
+                // Put the variant directory first in the search path to avoid runtime linking to the wrong library
+                cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
+            }
+        }
     }
 
     // Intel

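The variant decision is now made once for the whole system instead of per GPU: if any detected device is Pascal or older (compute capability below 7.5), the bundled CUDA v12 libraries are selected for every GPU, so a mixed Pascal-plus-newer machine gets the lowest common denominator; v13 is only chosen when all GPUs are CC 7.5 or newer and the driver is new enough. A second pass then applies the chosen variant and prepends its library directory to each GPU's DependencyPath.
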
docs/development.md
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
 go run . serve
 ```
 
+> [!NOTE]
+> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
+
+
 ## macOS (Apple Silicon)
 
 macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.

envconfig/config.go
@@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) {
     return loadTimeout
 }
 
+func Remotes() []string {
+    var r []string
+    raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
+    if raw == "" {
+        r = []string{"ollama.com"}
+    } else {
+        r = strings.Split(raw, ",")
+    }
+    return r
+}
+
 func Bool(k string) func() bool {
     return func() bool {
         if s := Var(k); s != "" {
@@ -270,6 +281,7 @@ func AsMap() map[string]EnvVar {
         "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
         "OLLAMA_CONTEXT_LENGTH":  {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
         "OLLAMA_NEW_ENGINE":      {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
+        "OLLAMA_REMOTES":         {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
 
         // Informational
         "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

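To illustrate the behaviour of the new Remotes() helper (hypothetical values; registry.example.internal is just a placeholder host), the variable defaults to ollama.com when unset and is otherwise split on commas:

```go
package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// Unset: the default allow-list is just ollama.com.
	os.Unsetenv("OLLAMA_REMOTES")
	fmt.Println(envconfig.Remotes()) // [ollama.com]

	// Set: the value is split on commas; entries are not trimmed individually.
	os.Setenv("OLLAMA_REMOTES", "ollama.com,registry.example.internal")
	fmt.Println(envconfig.Remotes()) // [ollama.com registry.example.internal]
}
```
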
@@ -243,6 +243,7 @@ func (kv KV) OllamaEngineRequired() bool {
         "gemma3",
         "gemma3n",
         "mistral3",
+        "qwen3",
         "llama4",
         "mllama",
         "qwen25vl",

integration/embed_test.go
@@ -8,6 +8,7 @@ import (
     "testing"
     "time"
 
+    "github.com/google/go-cmp/cmp"
     "github.com/ollama/ollama/api"
 )
 
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
     }
 
     res, err := embeddingTestHelper(ctx, client, t, req)
-
     if err != nil {
-        t.Fatalf("error: %v", err)
+        t.Fatal(err)
     }
 
     if len(res.Embedding) != 384 {
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
     }
 
     res, err := embedTestHelper(ctx, client, t, req)
-
     if err != nil {
-        t.Fatalf("error: %v", err)
+        t.Fatal(err)
     }
 
     if len(res.Embeddings) != 1 {
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
     }
 
     res, err := embedTestHelper(ctx, client, t, req)
-
     if err != nil {
-        t.Fatalf("error: %v", err)
+        t.Fatal(err)
     }
 
     if len(res.Embeddings) != 2 {
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 
     truncTrue, truncFalse := true, false
 
-    type testReq struct {
-        Name    string
-        Request api.EmbedRequest
-    }
-
-    reqs := []testReq{
-        {
-            Name: "Target Truncation",
-            Request: api.EmbedRequest{
-                Model: "all-minilm",
-                Input: "why",
-            },
-        },
-        {
-            Name: "Default Truncate",
-            Request: api.EmbedRequest{
-                Model:   "all-minilm",
-                Input:   "why is the sky blue?",
-                Options: map[string]any{"num_ctx": 1},
-            },
-        },
-        {
-            Name: "Explicit Truncate",
-            Request: api.EmbedRequest{
-                Model:    "all-minilm",
-                Input:    "why is the sky blue?",
-                Truncate: &truncTrue,
-                Options:  map[string]any{"num_ctx": 1},
-            },
-        },
-    }
-
-    res := make(map[string]*api.EmbedResponse)
-
-    for _, req := range reqs {
-        response, err := embedTestHelper(ctx, client, t, req.Request)
-        if err != nil {
-            t.Fatalf("error: %v", err)
-        }
-        res[req.Name] = response
-    }
-
-    if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
-        t.Fatal("expected default request to truncate correctly")
-    }
-
-    if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
-        t.Fatal("expected default request and truncate true request to be the same")
-    }
-
-    // check that truncate set to false returns an error if context length is exceeded
-    _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
-        Model:    "all-minilm",
-        Input:    "why is the sky blue?",
-        Truncate: &truncFalse,
-        Options:  map[string]any{"num_ctx": 1},
-    })
-
-    if err == nil {
-        t.Fatal("expected error, got nil")
-    }
+    want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
+        Model: "all-minilm",
+        Input: "why",
+    })
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    cases := []struct {
+        name    string
+        request api.EmbedRequest
+        check   func(*api.EmbedResponse, error)
+    }{
+        {
+            name: "target truncation",
+            request: api.EmbedRequest{
+                Model: "all-minilm",
+                Input: "why",
+            },
+            check: func(got *api.EmbedResponse, err error) {
+                if err != nil {
+                    t.Fatal(err)
+                }
+                if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+                    t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+                }
+            },
+        },
+        {
+            name: "default truncate",
+            request: api.EmbedRequest{
+                Model:   "all-minilm",
+                Input:   "why is the sky blue?",
+                Options: map[string]any{"num_ctx": 3},
+            },
+            check: func(got *api.EmbedResponse, err error) {
+                if err != nil {
+                    t.Fatal(err)
+                }
+                if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+                    t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+                }
+            },
+        },
+        {
+            name: "explicit truncate",
+            request: api.EmbedRequest{
+                Model:    "all-minilm",
+                Input:    "why is the sky blue?",
+                Truncate: &truncTrue,
+                Options:  map[string]any{"num_ctx": 3},
+            },
+            check: func(got *api.EmbedResponse, err error) {
+                if err != nil {
+                    t.Fatal(err)
+                }
+                if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+                    t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+                }
+            },
+        },
+        {
+            name: "truncate error",
+            request: api.EmbedRequest{
+                Model:    "all-minilm",
+                Input:    "why is the sky blue?",
+                Truncate: &truncFalse,
+                Options:  map[string]any{"num_ctx": 3},
+            },
+            check: func(res *api.EmbedResponse, err error) {
+                if err.Error() != "input exceeds maximum context length" {
+                    t.Fatalf("expected truncation error, got: %v", err)
+                }
+            },
+        },
+        {
+            name: "input after truncate error",
+            request: api.EmbedRequest{
+                Model:    "all-minilm",
+                Input:    "why is the sky blue?",
+                Truncate: &truncTrue,
+                Options:  map[string]any{"num_ctx": 1},
+            },
+            check: func(res *api.EmbedResponse, err error) {
+                if err.Error() != "input after truncation exceeds maximum context length" {
+                    t.Fatalf("expected truncation error, got: %v", err)
+                }
+            },
+        },
+        {
+            name: "input after truncate error",
+            request: api.EmbedRequest{
+                Model:    "all-minilm",
+                Input:    "why is the sky blue?",
+                Truncate: &truncTrue,
+                Options:  map[string]any{"num_ctx": 0},
+            },
+            check: func(res *api.EmbedResponse, err error) {
+                if err.Error() != "input after truncation exceeds maximum context length" {
+                    t.Fatalf("expected truncation error, got: %v", err)
+                }
+            },
+        },
+    }
+
+    for _, req := range cases {
+        t.Run(req.name, func(t *testing.T) {
+            req.check(embedTestHelper(ctx, client, t, req.request))
+        })
+    }
 }
 
 func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
+    t.Helper()
+
     if err := PullIfMissing(ctx, client, req.Model); err != nil {
-        t.Fatalf("failed to pull model %s: %v", req.Model, err)
+        t.Fatal(err)
     }
 
-    response, err := client.Embeddings(ctx, &req)
-
-    if err != nil {
-        return nil, err
-    }
-
-    return response, nil
+    return client.Embeddings(ctx, &req)
 }
 
 func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
+    t.Helper()
+
     if err := PullIfMissing(ctx, client, req.Model); err != nil {
-        t.Fatalf("failed to pull model %s: %v", req.Model, err)
+        t.Fatal(err)
     }
 
-    response, err := client.Embed(ctx, &req)
-
-    if err != nil {
-        return nil, err
-    }
-
-    return response, nil
+    return client.Embed(ctx, &req)
 }

@@ -5,6 +5,8 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const LevelTrace slog.Level = -8
|
const LevelTrace slog.Level = -8
|
||||||
@@ -29,10 +31,18 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type key string
|
||||||
|
|
||||||
func Trace(msg string, args ...any) {
|
func Trace(msg string, args ...any) {
|
||||||
slog.Log(context.TODO(), LevelTrace, msg, args...)
|
TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TraceContext(ctx context.Context, msg string, args ...any) {
|
func TraceContext(ctx context.Context, msg string, args ...any) {
|
||||||
slog.Log(ctx, LevelTrace, msg, args...)
|
if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) {
|
||||||
|
skip, _ := ctx.Value(key("skip")).(int)
|
||||||
|
pc, _, _, _ := runtime.Caller(1 + skip)
|
||||||
|
record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
|
||||||
|
record.Add(args...)
|
||||||
|
logger.Handler().Handle(ctx, record)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
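The skip count stored in the context is what keeps the reported source location on the code that called Trace rather than on the wrapper itself. A standalone sketch of the same pattern (hypothetical trace/traceWrapper names, not part of this patch):

package main

import (
	"context"
	"log/slog"
	"os"
	"runtime"
	"time"
)

type key string

// trace records msg against the caller's program counter, skipping any extra
// frames requested via a "skip" value in the context (mirrors TraceContext above).
func trace(ctx context.Context, msg string) {
	skip, _ := ctx.Value(key("skip")).(int)
	pc, _, _, _ := runtime.Caller(1 + skip)
	r := slog.NewRecord(time.Now(), slog.LevelInfo, msg, pc)
	_ = slog.Default().Handler().Handle(ctx, r)
}

// traceWrapper adds one stack frame, so it asks trace to skip it (mirrors Trace above).
func traceWrapper(msg string) {
	trace(context.WithValue(context.Background(), key("skip"), 1), msg)
}

func main() {
	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{AddSource: true})))
	traceWrapper("hello") // the logged source points at this line, not at traceWrapper
}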
@@ -416,6 +416,7 @@ type Tensor interface {
 	AddID(ctx Context, t2, ids Tensor) Tensor

 	Softmax(ctx Context) Tensor
+	L2Norm(ctx Context, eps float32) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
@@ -429,12 +430,13 @@ type Tensor interface {
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	GELU(ctx Context) Tensor
-	QuickGELU(ctx Context) Tensor
-	SILU(ctx Context) Tensor
-	RELU(ctx Context) Tensor
+	GELU(ctx Context, up ...Tensor) Tensor
+	SILU(ctx Context, up ...Tensor) Tensor
+	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
-	SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+
+	// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
+	SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
 	}
 }

+func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
 	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
 	if w != nil {
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }

-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }

-func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
-	}
-}
-
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }

-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }

-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
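The variadic up argument folds the common gate-then-multiply pattern into one call: when an up tensor is passed, the activation lowers to the split gated kernels shown above (ggml_geglu_split for GELU, ggml_swiglu_split for SILU, ggml_reglu_split for RELU) instead of an activation followed by an explicit Mul. A sketch of the call-site difference, in the shape the MLP hunks later in this diff use (fragment only, not a complete function):

// before: activation, then element-wise multiply with the up projection
hidden = mlp.Gate.Forward(ctx, hidden).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hidden))

// after: the up tensor is handed to the activation, which maps to one gated kernel
hidden = mlp.Gate.Forward(ctx, hidden).SILU(ctx, mlp.Up.Forward(ctx, hidden))
hidden = mlp.Down.Forward(ctx, hidden)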
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 }

 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
 		}

+		ctx.Forward(key, value)
 		if cache != nil {
 			cache.Put(ctx, key, value)
 		}
ml/nn/pooling/pooling.go (new file, 42 lines)
@@ -0,0 +1,42 @@
package pooling

import (
	"github.com/ollama/ollama/ml"
)

type Type uint32

const (
	TypeNone Type = iota
	TypeMean
	TypeCLS
	TypeLast
)

func (t Type) String() string {
	switch t {
	case TypeMean:
		return "Mean"
	case TypeCLS:
		return "CLS"
	case TypeLast:
		return "Last"
	default:
		return "Unknown"
	}
}

func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	switch t {
	case TypeMean:
		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	case TypeCLS:
		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
	case TypeLast:
		hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
		return hiddenStates
	default:
		panic("unknown pooling type")
	}
}
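Forward assumes dim 0 is the embedding dimension and dim 1 the token dimension, so CLS reads the first token, Last reads the final token, and Mean averages across tokens. The expected values in the test that follows come straight from that layout: with two tokens [0..7] and [8..15], the mean is [4..11]. A plain-Go sketch of the same arithmetic (no ml.Backend involved, written only to illustrate the semantics):

package main

import "fmt"

// poolMean, poolCLS and poolLast mirror pooling.TypeMean/TypeCLS/TypeLast
// for a [tokens][hidden] slice.
func poolMean(tokens [][]float32) []float32 {
	out := make([]float32, len(tokens[0]))
	for _, tok := range tokens {
		for i, v := range tok {
			out[i] += v / float32(len(tokens))
		}
	}
	return out
}

func poolCLS(tokens [][]float32) []float32  { return tokens[0] }
func poolLast(tokens [][]float32) []float32 { return tokens[len(tokens)-1] }

func main() {
	// Same data as the test below: Arange(0, 16) reshaped to 8x2,
	// i.e. two tokens with hidden size 8.
	tokens := [][]float32{
		{0, 1, 2, 3, 4, 5, 6, 7},
		{8, 9, 10, 11, 12, 13, 14, 15},
	}
	fmt.Println(poolMean(tokens)) // [4 5 6 7 8 9 10 11]
	fmt.Println(poolCLS(tokens))  // [0 1 2 3 4 5 6 7]
	fmt.Println(poolLast(tokens)) // [8 9 10 11 12 13 14 15]
}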
ml/nn/pooling/pooling_test.go (new file, 79 lines)
@@ -0,0 +1,79 @@
package pooling_test

import (
	"bytes"
	"os"
	"slices"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/discover"
	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/backend/ggml"
	"github.com/ollama/ollama/ml/nn/pooling"
)

func setup(tb testing.TB, n int) ml.Backend {
	tb.Helper()

	f, err := os.CreateTemp(tb.TempDir(), "*.bin")
	if err != nil {
		tb.Fatal(err)
	}
	defer f.Close()

	if err := fsggml.WriteGGUF(f, fsggml.KV{
		"general.architecture": "test",
		"test.block_count":     uint32(1),
	}, []*fsggml.Tensor{
		{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
	}); err != nil {
		tb.Fatal(err)
	}

	var gpuLayers ml.GPULayersList
	if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
		gpuLayers = append(gpuLayers, ml.GPULayers{
			ID: gpus[0].ID,
			Layers: slices.Collect(func(yield func(int) bool) {
				for i := range n {
					if !yield(i) {
						return
					}
				}
			}),
		})
	}
	b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
	if err != nil {
		tb.Fatal(err)
	}

	return b
}

func TestForward(t *testing.T) {
	cases := map[pooling.Type][]float32{
		pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
		pooling.TypeCLS:  {0, 1, 2, 3, 4, 5, 6, 7},
		pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
	}
	for typ, want := range cases {
		t.Run(typ.String(), func(t *testing.T) {
			b := setup(t, 99)
			defer b.Close()

			ctx := b.NewContext()
			defer ctx.Close()

			tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
			tt = typ.Forward(ctx, tt)

			ctx.Forward(tt).Compute(tt)
			if diff := cmp.Diff(want, tt.Floats()); diff != "" {
				t.Error(diff)
			}
		})
	}
}
@@ -54,10 +54,9 @@ type Batch struct {
 	// Inputs is the input tokens, including placeholders for multimodal inputs.
 	Inputs ml.Tensor

-	// Multimodal is a set of multimodal embeddings previously created by
-	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
-	// models or for batches without multimodal elements.
-	Multimodal []MultimodalIndex
+	// Outputs are the set of indicies into Inputs for which output data should
+	// be returned.
+	Outputs ml.Tensor

 	// Positions is the position for each Input, relative to its sequence. Equal
 	// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
 	// Sequences is the sequence for each Input. Equal in length to Inputs.
 	Sequences []int

-	// Outputs are the set of indicies into Inputs for which output data should
-	// be returned.
-	Outputs []int32
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
+	Multimodal []MultimodalIndex
 }
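With Outputs changed from []int32 to an ml.Tensor, the output-index tensor is built once by the caller and every model reads it directly, which is exactly what the per-model hunks later in this diff do. The call-site difference, as a fragment (not a complete function):

// before: each model converted the index slice itself
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

// after: Outputs is already an ml.Tensor supplied with the batch
outputs := batch.Outputs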
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	_ "image/jpeg"
 	_ "image/png"
-	"math"
 	"os"
 	"reflect"
 	"strconv"
@@ -21,10 +20,15 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
+	"github.com/ollama/ollama/ml/nn/pooling"
 	"github.com/ollama/ollama/model/input"
 )

-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
+var (
+	ErrNoVisionModel        = errors.New("this model is missing data required for image input")
+	ErrUnsupportedModel     = errors.New("model not supported")
+	ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
+)

 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
@@ -103,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
 		return nil, err
 	}

-	arch := b.Config().Architecture()
-	if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
-		arch = arch + "_embed"
-	}
-
-	f, ok := models[arch]
-	if !ok {
-		return nil, fmt.Errorf("unsupported model architecture %q", arch)
-	}
-
-	m, err := f(b.Config())
+	m, err := modelForArch(b.Config())
 	if err != nil {
 		return nil, err
 	}

 	base := Base{b: b, config: m.Config()}

 	v := reflect.ValueOf(m)
 	v.Elem().Set(populateFields(base, v.Elem()))
 	return m, nil
@@ -131,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
 		return nil, err
 	}
 	defer r.Close()

 	meta, err := fsggml.Decode(r, -1)
 	if err != nil {
 		return nil, err
 	}
-	return getTextProcessor(meta.KV())
-}
-
-func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
-	arch := kv.Architecture()
-	f, ok := models[arch]
-	if !ok {
-		return nil, fmt.Errorf("unsupported model architecture %q", arch)
-	}
-	m, err := f(kv)
+	m, err := modelForArch(meta.KV())
 	if err != nil {
 		return nil, err
 	}

 	tp, ok := m.(TextProcessor)
 	if !ok {
-		return nil, fmt.Errorf("%v is not a TextProcessor", m)
+		return nil, ErrUnsupportedTokenizer
 	}
 	return tp, nil
 }

+func modelForArch(c fs.Config) (Model, error) {
+	arch := c.Architecture()
+	if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
+		arch = arch + "_embed"
+	}
+
+	f, ok := models[arch]
+	if !ok {
+		return nil, ErrUnsupportedModel
+	}
+
+	return f(c)
+}
+
 func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 	t := v.Type()
@@ -242,7 +243,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
 		vv = vv.Elem()
 	}

-	vv = vv.Elem()
+	vv = reflect.Indirect(vv)
 	if v.IsNil() {
 		vv = reflect.New(v.Type().Elem()).Elem()
 	}
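The dispatch rule introduced here is small enough to show in isolation: the architecture name is looked up in a constructor registry, and a non-zero pooling_type reroutes the lookup to an "_embed" variant of the same architecture. A self-contained sketch of that rule (toy types and registry, not the real model package):

package main

import (
	"errors"
	"fmt"
)

type config map[string]uint32

type model interface{ Name() string }

type textModel struct{}
type embedModel struct{}

func (textModel) Name() string  { return "text" }
func (embedModel) Name() string { return "embed" }

var errUnsupported = errors.New("model not supported")

// registry maps architecture names to constructors, like models[...] above.
var registry = map[string]func(config) (model, error){
	"bert":       func(config) (model, error) { return textModel{}, nil },
	"bert_embed": func(config) (model, error) { return embedModel{}, nil },
}

// modelForArch mirrors the lookup rule: a non-zero pooling_type selects the
// "_embed" registration for the same architecture.
func modelForArch(arch string, c config) (model, error) {
	if c["pooling_type"] != 0 {
		arch += "_embed"
	}
	f, ok := registry[arch]
	if !ok {
		return nil, errUnsupported
	}
	return f(c)
}

func main() {
	m, _ := modelForArch("bert", config{})
	e, _ := modelForArch("bert", config{"pooling_type": 1})
	fmt.Println(m.Name(), e.Name()) // text embed
}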
@@ -1,9 +1,9 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/backend/ggml"
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseTags(t *testing.T) {
|
func TestParseTags(t *testing.T) {
|
||||||
@@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTextProcessor(t *testing.T) {
|
func TestModelForArch(t *testing.T) {
|
||||||
tp, err := getTextProcessor(fsggml.KV{})
|
type fakeModel struct {
|
||||||
if err == nil {
|
Model
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
models["dummy"] = func(fs.Config) (Model, error) {
|
type fakeEmbeddingModel struct {
|
||||||
return notTextProcessorModel{}, nil
|
Model
|
||||||
}
|
}
|
||||||
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
|
|
||||||
if err == nil {
|
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
|
||||||
t.Error("expected error")
|
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
|
||||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
cases := []struct {
|
||||||
} else if tp != nil {
|
name string
|
||||||
t.Error("expected nil tp")
|
config fs.Config
|
||||||
|
want any
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "model",
|
||||||
|
},
|
||||||
|
want: fakeModel{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "embedding",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "model",
|
||||||
|
"model.pooling_type": uint32(1),
|
||||||
|
},
|
||||||
|
want: fakeEmbeddingModel{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported",
|
||||||
|
config: fsggml.KV{
|
||||||
|
"general.architecture": "unsupported",
|
||||||
|
},
|
||||||
|
err: ErrUnsupportedModel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := modelForArch(tt.config)
|
||||||
|
if !errors.Is(err, tt.err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type notTextProcessorModel struct{}
|
|
||||||
|
|
||||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (notTextProcessorModel) Backend() ml.Backend {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (notTextProcessorModel) Config() config {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
model/models/bert/embed.go (new file, 181 lines)
@@ -0,0 +1,181 @@
package bert

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Model struct {
	model.Base
	model.TextProcessor

	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
	TypeEmbedding      *nn.Embedding `gguf:"token_types"`
	PositionEmbedding  *nn.Embedding `gguf:"position_embd"`
	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`

	Layers []EncoderLayer `gguf:"blk"`

	Options
}

// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)

	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
	}

	hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
	if m.normalize {
		hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
	}

	return hiddenStates, nil
}

type EncoderLayer struct {
	*Attention
	AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`

	*MLP
	MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}

func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	// Attention
	residual := hiddenStates
	hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)

	// MLP
	residual = hiddenStates
	hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)
	hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)

	return hiddenStates
}

type Attention struct {
	Query     *nn.Linear    `gguf:"attn_q"`
	QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`

	Key     *nn.Linear    `gguf:"attn_k"`
	KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`

	Value *nn.Linear `gguf:"attn_v"`

	Output *nn.Linear `gguf:"attn_output"`
}

func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	batchSize := hiddenStates.Dim(1)

	query := a.Query.Forward(ctx, hiddenStates)
	if a.QueryNorm != nil {
		query = a.QueryNorm.Forward(ctx, query, opts.eps)
	}
	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)

	key := a.Key.Forward(ctx, hiddenStates)
	if a.KeyNorm != nil {
		key = a.KeyNorm.Forward(ctx, key, opts.eps)
	}
	key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)

	value := a.Value.Forward(ctx, hiddenStates)
	value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)

	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
	return a.Output.Forward(ctx, attention)
}

type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}

type Options struct {
	hiddenSize,
	numHeads,
	numKVHeads,
	keyLength,
	valueLength int
	poolingType pooling.Type
	eps         float32
	normalize   bool
}

func (o Options) headDim() int {
	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}

func New(c fs.Config) (model.Model, error) {
	var processor model.TextProcessor
	switch c.String("tokenizer.ggml.model", "bert") {
	case "bert":
		processor = model.NewWordPiece(
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Scores: c.Floats("tokenizer.ggml.scores"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				BOS: []int32{
					int32(cmp.Or(
						c.Uint("tokenizer.ggml.cls_token_id"),
						c.Uint("tokenizer.ggml.bos_token_id"),
					)),
				},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
				EOS: []int32{
					int32(cmp.Or(
						c.Uint("tokenizer.ggml.separator_token_id"),
						//nolint:misspell
						// NOTE: "seperator_token_id" is a typo in model metadata but we need to
						// support it for compatibility.
						c.Uint("tokenizer.ggml.seperator_token_id"),
						c.Uint("tokenizer.ggml.eos_token_id"),
					)),
				},
			},
		)
	default:
		return nil, model.ErrUnsupportedTokenizer
	}

	return &Model{
		TextProcessor: processor,
		Layers:        make([]EncoderLayer, c.Uint("block_count")),
		Options: Options{
			hiddenSize:  int(c.Uint("embedding_length")),
			numHeads:    int(c.Uint("attention.head_count")),
			numKVHeads:  int(c.Uint("attention.head_count_kv")),
			eps:         c.Float("attention.layer_norm_epsilon"),
			poolingType: pooling.Type(c.Uint("pooling_type")),
			normalize:   c.Bool("normalize_embeddings", true),
		},
	}, nil
}

func init() {
	model.Register("bert", New)
	model.Register("bert_embed", New)
}
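Because Forward ends by L2-normalizing the pooled vector (when normalize_embeddings is set), cosine similarity between two returned embeddings reduces to a plain dot product. A small standalone illustration of that property (ordinary Go arithmetic, not the tensor path):

package main

import (
	"fmt"
	"math"
)

// l2norm scales v to unit length, like the L2Norm call at the end of Forward.
func l2norm(v []float64) []float64 {
	var ss float64
	for _, x := range v {
		ss += x * x
	}
	n := math.Sqrt(ss) + 1e-12
	out := make([]float64, len(v))
	for i, x := range v {
		out[i] = x / n
	}
	return out
}

func dot(a, b []float64) float64 {
	var s float64
	for i := range a {
		s += a[i] * b[i]
	}
	return s
}

func main() {
	a := l2norm([]float64{1, 2, 3})
	b := l2norm([]float64{2, 4, 5})
	// On unit vectors the dot product is the cosine similarity.
	fmt.Printf("cosine similarity: %.4f\n", dot(a, b))
}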
@@ -24,7 +24,7 @@ type Options struct {

 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`
@@ -40,7 +40,7 @@ const (

 func New(c fs.Config) (model.Model, error) {
 	m := Model{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
 			attnValLen:        int(c.Uint("attention.value_length")),
 			eps:               c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:          c.Float("rope.freq_base", 10000.0),
-			ropeScale:         c.Float("rope.freq_scale", 1.0),
+			ropeScale:         c.Float("rope.scaling.factor", 1.0),
 			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap: c.Float("final_logit_softcapping"),
 		},
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())

 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
 }

 type MLP struct {
@@ -138,7 +138,7 @@ type MLP struct {
 }

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten

 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
+			lastLayerOutputs = batch.Outputs
 		}

 		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
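The same two-part RoPE change repeats across the model files below: the scale is now read from rope.scaling.factor instead of rope.freq_scale, and every fast.RoPE call passes 1./ropeScale, i.e. the stored factor is treated as the reciprocal of the old frequency scale (a scaling factor of 8 would correspond to a frequency scale of 0.125). A two-line fragment of the convention, for illustration only:

ropeScale := c.Float("rope.scaling.factor", 1.0) // e.g. 8
freqScale := 1. / ropeScale                      // e.g. 0.125, the value handed to fast.RoPE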
@@ -1,49 +1,38 @@
 package gemma3

 import (
-	"errors"
-
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/pooling"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 )

 type embedModel struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	*TextModel
-	PoolingType uint32
+	poolingType pooling.Type

 	Dense [2]*nn.Linear `gguf:"dense"`
 }

 func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	batch.Outputs = batch.Positions // return all positions
 	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
+	hiddenStates = m.poolingType.Forward(ctx, hiddenStates)

-	switch m.PoolingType {
-	case 0: // None
-	case 1: // Mean
-		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
-		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	default:
-		return nil, errors.New("unsupported pooling type")
-	}
-
 	for _, dense := range m.Dense {
 		hiddenStates = dense.Forward(ctx, hiddenStates)
 	}
+	hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
 	return hiddenStates, nil
 }

 func newEmbedModel(c fs.Config) (model.Model, error) {
 	m := &embedModel{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
 			},
 		),
 		TextModel: newTextModel(c),
-		PoolingType: c.Uint("pooling_type", 0),
+		poolingType: pooling.Type(c.Uint("pooling_type", 0)),
 	}

 	m.Cache = kvcache.NewWrapperCache(
@@ -16,7 +16,7 @@ import (

 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	*VisionModel `gguf:"v"`
 	*TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i

 func New(c fs.Config) (model.Model, error) {
 	m := Model{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
 			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+			ropeScale:      1,
+			// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
+			// (8 instead of 1)
+			// ropeScale: c.Float("rope.scaling.factor", 1.0),
 		},
 	}

@@ -84,7 +87,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())

 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -95,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.TextConfig.ropeGlobalBase
 	}

-	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
 }

 type TextMLP struct {
@@ -123,7 +126,7 @@ type TextMLP struct {
 }

 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -161,7 +164,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,

 func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))

 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +196,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac

 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
+			lastLayerOutputs = batch.Outputs
 		}

 		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
@@ -10,7 +10,7 @@ import (

 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece

 	*TextModel
 }
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 func New(c fs.Config) (model.Model, error) {
 	m := Model{
 		TextModel: newTextModel(c),
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac

 	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
 	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
-	hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+	hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)

 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
 	return m.Output.Forward(ctx, hiddenStates), nil
@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.ropeBaseLocal
 	}

-	return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
 }

 type TextScaledWordEmbedding struct {
@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	}

 	active = d.PerLayerInputGate.Forward(ctx, active)
-	active = active.GELU(ctx)
-	active = active.Mul(ctx, perLayerInput)
+	active = active.GELU(ctx, perLayerInput)

 	active = d.PerLayerProjection.Forward(ctx, active)
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
 	query := attn.Query.Forward(ctx, hiddenStates)
 	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
 	query = attn.QueryNorm.Forward(ctx, query, opts.eps)
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())

 	var key, value ml.Tensor
 	if !sharedKV {
 		key = attn.Key.Forward(ctx, hiddenStates)
 		key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
 		key = attn.KeyNorm.Forward(ctx, key, opts.eps)
-		key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+		key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())

 		value = attn.Value.Forward(ctx, hiddenStates)
 		value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
 		hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
 	}

-	hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
+	hiddenStates = hiddenStates.GELU(ctx, upStates)
 	hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
 	return hiddenStates
 }
@@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:           c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeBase:      c.Float("rope.freq_base", 1_000_000),
 			ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
-			ropeScale:     c.Float("rope.freq_scale", 1.0),
+			ropeScale:     c.Float("rope.scaling.factor", 1.0),

 			slidingWindowPattern:    c.Bools("attention.sliding_window_pattern"),
 			activationSparsityScale: c.Floats("activation_sparsity_scale"),
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
 	}

 		var outputs ml.Tensor
-		if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+		if i == len(m.TransformerBlocks)-1 {
+			outputs = batch.Outputs
 		}

 		hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
 	}

-	hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
+	hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)

 	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)
@@ -2,7 +2,6 @@ package llama

 import (
 	"cmp"
-	"fmt"
 	"math"

 	"github.com/ollama/ollama/fs"
@@ -23,51 +22,60 @@ type Options struct {

 type Model struct {
 	model.Base
-	model.BytePairEncoding
+	model.TextProcessor

 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`
 	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
 	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

-	*Options
+	Options
 }

 func New(c fs.Config) (model.Model, error) {
-	// This model currently only supports the gpt2 tokenizer
-	if c.String("tokenizer.ggml.model") == "llama" {
-		return nil, fmt.Errorf("unsupported tokenizer: llama")
-	}
-	// Best effort detection of library/deepseek-coder model(s) which are incompatible
-	if c.String("general.name") == "deepseek-ai" {
-		return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
-	}
-	m := Model{
-		BytePairEncoding: model.NewBytePairEncoding(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
-			&model.Vocabulary{
-				Values: c.Strings("tokenizer.ggml.tokens"),
-				Types:  c.Ints("tokenizer.ggml.token_type"),
-				Merges: c.Strings("tokenizer.ggml.merges"),
-				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
-				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-				EOS: append(
-					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
-					c.Ints("tokenizer.ggml.eos_token_ids")...,
-				),
-			},
-		),
-		Layers: make([]Layer, c.Uint("block_count")),
-		Options: &Options{
+	if c.Uint("expert_count") > 0 {
+		// TODO: support mixtures of experts
+		return nil, model.ErrUnsupportedModel
+	}
+
+	var processor model.TextProcessor
+	vocabulary := model.Vocabulary{
+		Values: c.Strings("tokenizer.ggml.tokens"),
+		Scores: c.Floats("tokenizer.ggml.scores"),
+		Types:  c.Ints("tokenizer.ggml.token_type"),
+		Merges: c.Strings("tokenizer.ggml.merges"),
+		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+		EOS: append(
+			[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+			c.Ints("tokenizer.ggml.eos_token_ids")...,
+		),
+	}
+	switch c.String("tokenizer.ggml.model") {
+	case "gpt2":
+		processor = model.NewBytePairEncoding(
+			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+			&vocabulary,
+		)
+	case "llama":
+		processor = model.NewSentencePiece(&vocabulary)
+	default:
+		return nil, model.ErrUnsupportedTokenizer
+	}
+
+	m := Model{
+		TextProcessor: processor,
+		Layers:        make([]Layer, c.Uint("block_count")),
+		Options: Options{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			headDim:    int(c.Uint("attention.key_length")),
 			ropeDim:    int(c.Uint("rope.dimension_count")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
+			ropeBase:   c.Float("rope.freq_base", 1e5),
+			ropeScale:  c.Float("rope.scaling.factor", 1),
 		},
 	}
@@ -98,8 +106,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

-	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
-	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))

 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
 	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -108,7 +116,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
+	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
 }

 type MLP struct {
@@ -118,7 +126,7 @@ type MLP struct {
 }

 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -160,10 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}

-		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
 	}

 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
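The `SILU(ctx).Mul(ctx, up)` → `SILU(ctx, up)` rewrites that recur throughout this diff read as a fused SwiGLU gate: the second argument is multiplied elementwise by the SiLU of the receiver. A small self-contained sketch of the math the fused call expresses (plain float64 slices stand in for `ml.Tensor`; this is not the tensor API itself):

package main

import (
	"fmt"
	"math"
)

func silu(x float64) float64 { return x / (1 + math.Exp(-x)) }

// swiGLU mirrors what the fused SILU(ctx, up) call computes: silu(gate) * up.
func swiGLU(gate, up []float64) []float64 {
	out := make([]float64, len(gate))
	for i := range gate {
		out[i] = silu(gate[i]) * up[i]
	}
	return out
}

func main() {
	fmt.Println(swiGLU([]float64{1, -2}, []float64{0.5, 3}))
}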
@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
 	if useRope {
-		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
-		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 	}
 
 	if opts.useQKNorm {
@@ -58,14 +58,14 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
 type TextExperts struct {
-	Gate *nn.Linear `gguf:"ffn_gate_exps"`
-	Up   *nn.Linear `gguf:"ffn_up_exps"`
-	Down *nn.Linear `gguf:"ffn_down_exps"`
+	Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up   *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
 }
 
 func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
 	hiddenStates = hiddenStates.Mul(ctx, scores)
 
-	upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
-	gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
-	downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+	upStates := e.Up.Forward(ctx, hiddenStates, experts)
+	gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
+	downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
 
 	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {
@@ -96,7 +96,7 @@ type TextSharedExpert struct {
 }
 
 func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
@@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel {
 		numExpertsUsed: int(c.Uint("expert_used_count")),
 		ropeDim:        int(c.Uint("rope.dimension_count")),
 		ropeBase:       c.Float("rope.freq_base"),
-		ropeScale:      c.Float("rope.freq_scale", 1),
+		ropeScale:      c.Float("rope.scaling.factor", 1),
 		eps:            c.Float("attention.layer_norm_rms_epsilon"),
 		interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
 		noRopeInterval:      int(c.Uint("no_rope_interval", 4)),
@@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
 }
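Swapping `*nn.Linear` plus `Weight.MulmatID` for `*nn.LinearBatch` with a two-argument `Forward` moves the expert-indexed matmul behind the layer type. A hedged sketch of what such a batched linear amounts to; names and types here are illustrative stand-ins, not the ollama tensor API, and the indices select one expert weight matrix per token:

package main

import "fmt"

// forwardBatched applies, per token t, the weight matrix of the expert chosen
// for that token: out[t] = weights[experts[t]] * x[t]. Illustrative only; the
// real nn.LinearBatch runs this as a single tensor op (formerly MulmatID).
func forwardBatched(weights [][][]float32, x [][]float32, experts []int) [][]float32 {
	out := make([][]float32, len(x))
	for t, row := range x {
		w := weights[experts[t]]
		y := make([]float32, len(w))
		for o := range w {
			for i, v := range row {
				y[o] += w[o][i] * v
			}
		}
		out[t] = y
	}
	return out
}

func main() {
	w := [][][]float32{{{1, 0}, {0, 1}}, {{2, 0}, {0, 2}}} // two 2x2 experts
	fmt.Println(forwardBatched(w, [][]float32{{1, 1}, {1, 1}}, []int{0, 1}))
}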
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
 }
 
 type MLP struct {
@@ -65,7 +65,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
 
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
 		ropeDim:  int(c.Uint("rope.dimension_count")),
 		eps:      c.Float("attention.layer_norm_rms_epsilon"),
 		ropeBase: c.Float("rope.freq_base"),
-		ropeScale: c.Float("rope.freq_scale", 1),
+		ropeScale: c.Float("rope.scaling.factor", 1),
 		},
 	}
 }
@@ -51,7 +51,7 @@ type VisionMLP struct {
 }
 
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	}
 
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+	key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
+		return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
 	}
 
 	return key, nil
@@ -58,7 +58,7 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
 
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
 		ropeDim:  int(c.Uint("rope.dimension_count")),
 		eps:      c.Float("attention.layer_norm_rms_epsilon"),
 		ropeBase: c.Float("rope.freq_base"),
-		ropeScale: c.Float("rope.freq_scale", 1),
+		ropeScale: c.Float("rope.scaling.factor", 1),
 		crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
 		},
 	}
@@ -1,6 +1,7 @@
 package models
 
 import (
+	_ "github.com/ollama/ollama/model/models/bert"
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/gemma3n"
@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	value := attn.Value.Forward(ctx, hiddenStates)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
 	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -59,7 +59,7 @@ type MLP struct {
 }
 
 func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 func New(c fs.Config) (model.Model, error) {
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
 		headDim:  int(c.Uint("attention.key_length")),
 		ropeDim:  int(c.Uint("rope.dimension_count")),
 		ropeBase: c.Float("rope.freq_base"),
-		ropeScale: c.Float("rope.freq_scale", 1),
+		ropeScale: c.Float("rope.scaling.factor", 1),
 		eps:       c.Float("attention.layer_norm_rms_epsilon"),
 		},
 	}
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
 }
 
 func init() {
@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
 		originalContextLength: int(c.Uint("context_length", 128000)),
 		eps:      c.Float("attention.layer_norm_rms_epsilon"),
 		ropeBase: c.Float("rope.freq_base"),
-		ropeScale: c.Float("rope.freq_scale", 1),
+		ropeScale: c.Float("rope.scaling.factor", 1),
 		},
 	}
 
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation
@@ -90,7 +90,7 @@ type MLP struct {
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
 	// Apply SwiGLU activation gating
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	// Project back to hidden dimension
 	return mlp.Down.Forward(ctx, hiddenState)
 }
@@ -100,8 +100,7 @@ type VisionMLP struct {
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	// Using activation as specified in config (likely GELU or SiLU/Swish)
 	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
-	upOutput := mlp.Up.Forward(ctx, hiddenStates)
-	hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
+	hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
model/models/qwen3/embed.go (new file, 73 lines)
@@ -0,0 +1,73 @@
+package qwen3
+
+import (
+	"github.com/ollama/ollama/fs"
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn/pooling"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type embedModel struct {
+	model.Base
+	model.BytePairEncoding
+
+	*Model
+	poolingType pooling.Type
+}
+
+func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	hiddenStates, err := m.forward(ctx, batch)
+	if err != nil {
+		return nil, err
+	}
+
+	hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
+	hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
+	return hiddenStates, nil
+}
+
+func newEmbed(c fs.Config) (model.Model, error) {
+	layers := make([]Layer, c.Uint("block_count"))
+	for i := range layers {
+		layers[i].MLP = &dense{}
+	}
+	m := embedModel{
+		BytePairEncoding: model.NewBytePairEncoding(
+			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Types:  c.Ints("tokenizer.ggml.token_type"),
+				Merges: c.Strings("tokenizer.ggml.merges"),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+				EOS: append(
+					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+					c.Ints("tokenizer.ggml.eos_token_ids")...,
+				),
+			},
+		),
+		Model: &Model{
+			Layers: layers,
+			Options: &Options{
+				hiddenSize:     int(c.Uint("embedding_length")),
+				numHeads:       int(c.Uint("attention.head_count")),
+				numKVHeads:     int(c.Uint("attention.head_count_kv")),
+				keyLength:      int(c.Uint("attention.key_length")),
+				valueLength:    int(c.Uint("attention.value_length")),
+				eps:            c.Float("attention.layer_norm_rms_epsilon"),
+				ropeBase:       c.Float("rope.freq_base"),
+				ropeScale:      c.Float("rope.freq_scale", 1),
+				numExperts:     int(c.Uint("expert_count")),
+				numExpertsUsed: int(c.Uint("expert_used_count")),
+				normTopKProb:   c.Bool("norm_top_k_prob", true),
+			},
+		},
+		poolingType: pooling.Type(c.Uint("pooling_type")),
+	}
+
+	m.Cache = kvcache.NewCausalCache(m.Shift)
+	return &m, nil
+}
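For readers unfamiliar with the embedding path added above: the hidden states are pooled across the sequence and then L2-normalized, so cosine similarity between embeddings reduces to a dot product. A small self-contained sketch of that post-processing using plain slices rather than `ml.Tensor`; mean pooling here is an assumption, since the real `pooling.Type` comes from the GGUF `pooling_type` field:

package main

import (
	"fmt"
	"math"
)

// meanPoolAndNormalize mirrors the embed path: pool token vectors, then
// L2-normalize, as pooling.Type.Forward followed by L2Norm(ctx, 1e-12) does.
func meanPoolAndNormalize(tokens [][]float64) []float64 {
	out := make([]float64, len(tokens[0]))
	for _, t := range tokens {
		for i, v := range t {
			out[i] += v / float64(len(tokens))
		}
	}
	var norm float64
	for _, v := range out {
		norm += v * v
	}
	norm = math.Sqrt(norm) + 1e-12 // epsilon guards the zero vector
	for i := range out {
		out[i] /= norm
	}
	return out
}

func main() {
	fmt.Println(meanPoolAndNormalize([][]float64{{1, 2}, {3, 4}}))
}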
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
 }
 
 type Attention struct {
-	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
 	Query     *nn.Linear  `gguf:"attn_q"`
-	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
+	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
 	Key       *nn.Linear  `gguf:"attn_k"`
+	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
 	Value     *nn.Linear  `gguf:"attn_v"`
 	Output    *nn.Linear  `gguf:"attn_output"`
 }
@@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
 	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
 	key = sa.KeyNorm.Forward(ctx, key, opts.eps)
 
-	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
-	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+	query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+	key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
 
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
 	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -65,10 +65,10 @@ type MLP interface {
 }
 
 type sparse struct {
 	Router *nn.Linear `gguf:"ffn_gate_inp"`
-	Gate   *nn.Linear `gguf:"ffn_gate_exps"`
-	Up     *nn.Linear `gguf:"ffn_up_exps"`
-	Down   *nn.Linear `gguf:"ffn_down_exps"`
+	Gate   *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up     *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down   *nn.LinearBatch `gguf:"ffn_down_exps"`
 }
 
 func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
 
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
 
-	upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
-
-	hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
-	hiddenStates = hiddenStates.SILU(ctx)
-	hiddenStates = hiddenStates.Mul(ctx, upStates)
-
-	experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
+	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)
 
 	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -111,7 +107,8 @@ type dense struct {
 }
 
 func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
+		SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
@@ -154,29 +151,39 @@ type Model struct {
 	*Options
 }
 
-// Forward implements model.Model.
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	hiddenStates, err := m.forward(ctx, batch)
+	if err != nil {
+		return nil, err
+	}
+
+	return m.Output.Forward(ctx, hiddenStates), nil
+}
+
+// Forward implements model.Model.
+func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
 	for i, layer := range m.Layers {
-		m.Cache.SetLayer(i)
+		if m.Cache != nil {
+			m.Cache.SetLayer(i)
+		}
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
 	}
 
-	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
-	return m.Output.Forward(ctx, hiddenStates), nil
+	return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 var _ model.Model = (*Model)(nil)
@@ -216,7 +223,7 @@ func New(c fs.Config) (model.Model, error) {
 		valueLength: int(c.Uint("attention.value_length")),
 		eps:         c.Float("attention.layer_norm_rms_epsilon"),
 		ropeBase:    c.Float("rope.freq_base"),
-		ropeScale:   c.Float("rope.freq_scale", 1),
+		ropeScale:   c.Float("rope.scaling.factor", 1),
 		numExperts:     int(c.Uint("expert_count")),
 		numExpertsUsed: int(c.Uint("expert_used_count")),
 		normTopKProb:   c.Bool("norm_top_k_prob", true),
@@ -230,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
 func init() {
 	model.Register("qwen3", New)
 	model.Register("qwen3moe", New)
+	model.Register("qwen3_embed", newEmbed)
 }
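The restructuring above splits the qwen3 forward pass so the new embedding model and the text model can share one backbone: `forward` runs the layers and the output norm, the text model's `Forward` adds the vocabulary projection, and `embedModel.Forward` (registered as `qwen3_embed`) adds pooling and normalization instead. A hedged, illustrative-only sketch of that pattern with made-up types, not the ollama interfaces:

package main

import "fmt"

// One shared forward pass, two heads. Names and shapes are illustrative.
type backbone struct{}

func (backbone) forward(n int) []float64 { return make([]float64, n) }

type textModel struct{ backbone }

func (m textModel) Forward(n int) []float64 {
	h := m.forward(n)
	return h // the real text model projects h to vocabulary logits here
}

type embedModel struct{ backbone }

func (m embedModel) Forward(n int) []float64 {
	h := m.forward(n)
	return h // the real embed model pools and L2-normalizes h here
}

func main() {
	fmt.Println(len(textModel{}.Forward(4)), len(embedModel{}.Forward(4)))
}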
model/parsers/parsers.go (new file, 37 lines)
@@ -0,0 +1,37 @@
+package parsers
+
+import (
+	"github.com/ollama/ollama/api"
+)
+
+type Parser interface {
+	Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
+	HasToolSupport() bool
+	HasThinkingSupport() bool
+}
+
+func ParserForName(name string) Parser {
+	switch name {
+	case "qwen3-coder":
+		parser := &Qwen3CoderParser{}
+		return parser
+	case "passthrough":
+		return &PassthroughParser{}
+	default:
+		return nil
+	}
+}
+
+type PassthroughParser struct{}
+
+func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+	return s, "", nil, nil
+}
+
+func (p *PassthroughParser) HasToolSupport() bool {
+	return false
+}
+
+func (p *PassthroughParser) HasThinkingSupport() bool {
+	return false
+}
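A hedged usage sketch of the parser registry added above, from the point of view of calling code: the server-side wiring that actually drives these parsers is assumed here, not shown in this diff, but the `ParserForName` and `Parser.Add` signatures come straight from the file.

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/parsers"
)

func main() {
	p := parsers.ParserForName("qwen3-coder")
	if p == nil {
		return // unknown parser name
	}
	var tools []api.Tool // tool definitions would come from the request
	// Add is called with each streamed chunk; it returns plain content,
	// thinking text, and any completed tool calls parsed so far.
	content, thinking, calls, err := p.Add("chunk of model output", tools)
	if err != nil {
		return
	}
	fmt.Println(content, thinking, len(calls))
}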
model/parsers/qwen3coder.go (new file, 447 lines)
@@ -0,0 +1,447 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type qwenParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
toolOpenTag = "<tool_call>"
|
||||||
|
toolCloseTag = "</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwenParserState_LookingForToolStart qwenParserState = iota
|
||||||
|
qwenParserState_CollectingToolContent
|
||||||
|
)
|
||||||
|
|
||||||
|
type Qwen3CoderParser struct {
|
||||||
|
state qwenParserState
|
||||||
|
acc strings.Builder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.acc.WriteString(s)
|
||||||
|
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case qwenEventRawToolCall:
|
||||||
|
toolCall, err := parseToolCall(event, tools)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
|
return "", "", nil, err
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
case qwenEventContent:
|
||||||
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
|
// events, we naively append them together here. See the note below about
|
||||||
|
// `qwenEvent`s for more details
|
||||||
|
sb.WriteString(event.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), "", toolCalls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
|
||||||
|
var all []qwenEvent
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []qwenEvent
|
||||||
|
events, keepLooping = eat(p)
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(all) > 0 {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
// we use some internal event types in order to communicate between `Add` and
|
||||||
|
// `eat`. We do this to support interleaving content and parallel tool calls in
|
||||||
|
// the parser, even though qwen3-coder isn't supposed to do this. Our API
|
||||||
|
// doesn't currently support models outputting multiple messages in a turn, so
|
||||||
|
// we wouldn't be able to represent it yet, but there's no reason to prevent the
|
||||||
|
// parser from supporting it, especially for future models if they end up using
|
||||||
|
// a similar format.
|
||||||
|
type qwenEvent interface {
|
||||||
|
isQwenEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwenEventRawToolCall struct {
|
||||||
|
raw string
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwenEventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwenEventContent) isQwenEvent() {}
|
||||||
|
func (qwenEventRawToolCall) isQwenEvent() {}
|
||||||
|
|
||||||
|
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||||
|
// events from the current parser state. If the parser transitions to another
|
||||||
|
// state, it may have additional events to emit on the next call, which is what
|
||||||
|
// the second return value indicates
|
||||||
|
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||||
|
var events []qwenEvent
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case qwenParserState_LookingForToolStart:
|
||||||
|
if strings.Contains(p.acc.String(), toolOpenTag) {
|
||||||
|
// we found a full tool open tag, so we can emit the content before the
|
||||||
|
// tag, being sure to trim any trailing whitespace
|
||||||
|
split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
|
||||||
|
before := split[0]
|
||||||
|
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||||
|
if len(before) > 0 {
|
||||||
|
events = append(events, qwenEventContent{content: before})
|
||||||
|
}
|
||||||
|
after := split[1]
|
||||||
|
p.acc.Reset()
|
||||||
|
p.acc.WriteString(after)
|
||||||
|
p.state = qwenParserState_CollectingToolContent
|
||||||
|
return events, true
|
||||||
|
} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
|
||||||
|
// we found a partial tool open tag, so we can emit the unambiguous part,
|
||||||
|
// which is the (trailing-whitespace trimmed) content before the partial
|
||||||
|
// tool open tag
|
||||||
|
beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
|
||||||
|
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||||
|
unambiguous := p.acc.String()[:ambiguousStart]
|
||||||
|
ambiguous := p.acc.String()[ambiguousStart:]
|
||||||
|
p.acc.Reset()
|
||||||
|
p.acc.WriteString(ambiguous)
|
||||||
|
events = append(events, qwenEventContent{content: unambiguous})
|
||||||
|
return events, false
|
||||||
|
} else {
|
||||||
|
// we found content that is entirely not a tool call. We should withhold
|
||||||
|
// any trailing whitespace in case this is the end of the content
|
||||||
|
whitespaceLen := trailingWhitespaceLen(p.acc.String())
|
||||||
|
ambiguousStart := len(p.acc.String()) - whitespaceLen
|
||||||
|
unambiguous := p.acc.String()[:ambiguousStart]
|
||||||
|
ambiguous := p.acc.String()[ambiguousStart:]
|
||||||
|
p.acc.Reset()
|
||||||
|
p.acc.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwenEventContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
case qwenParserState_CollectingToolContent:
|
||||||
|
if strings.Contains(p.acc.String(), toolCloseTag) {
|
||||||
|
split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
|
||||||
|
before := split[0]
|
||||||
|
if len(before) == 0 {
|
||||||
|
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||||
|
}
|
||||||
|
// remove any whitespace between the tool call and any content after it
|
||||||
|
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||||
|
p.acc.Reset()
|
||||||
|
p.acc.WriteString(after)
|
||||||
|
events = append(events, qwenEventRawToolCall{raw: before})
|
||||||
|
p.state = qwenParserState_LookingForToolStart
|
||||||
|
return events, true
|
||||||
|
} else {
|
||||||
|
// note that we don't need to check the overlap here because we only plan
|
||||||
|
// on parsing the tool call once we see the full closing tag. We don't
|
||||||
|
// stream back the unparsed tool content, so there's no need to be eager
|
||||||
|
// here
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(drifkin): move this to a shared location
|
||||||
|
// longest overlap between suffix of s and prefix of delim
|
||||||
|
func overlap(s, delim string) int {
|
||||||
|
max := min(len(delim), len(s))
|
||||||
|
for i := max; i > 0; i-- {
|
||||||
|
if strings.HasSuffix(s, delim[:i]) {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func trailingWhitespaceLen(s string) int {
|
||||||
|
for i := len(s) - 1; i >= 0; i-- {
|
||||||
|
if !unicode.IsSpace(rune(s[i])) {
|
||||||
|
return len(s) - i - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionCall struct {
|
||||||
|
XMLName xml.Name `xml:"function"`
|
||||||
|
Name string `xml:"name,attr"`
|
||||||
|
Parameters []XMLParameter `xml:"parameter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLParameter struct {
|
||||||
|
Name string `xml:"name,attr"`
|
||||||
|
Value string `xml:",chardata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseToolCall parses a raw tool call string into an api.ToolCall.
|
||||||
|
// The raw string follows an xml-like format, here's an example:
|
||||||
|
//
|
||||||
|
// <function=get_current_temperature>
|
||||||
|
// <parameter=location>
|
||||||
|
// San Francisco
|
||||||
|
// </parameter>
|
||||||
|
// <parameter=unit>
|
||||||
|
// celsius
|
||||||
|
// </parameter>
|
||||||
|
// </function>
|
||||||
|
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||||
|
toolCall := api.ToolCall{}
|
||||||
|
|
||||||
|
xmlString := transformToXML(raw.raw)
|
||||||
|
|
||||||
|
var functionCall XMLFunctionCall
|
||||||
|
err := xml.Unmarshal([]byte(xmlString), &functionCall)
|
||||||
|
if err != nil {
|
||||||
|
return api.ToolCall{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall.Function = api.ToolCallFunction{
|
||||||
|
Name: functionCall.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the matching tool to get parameter types
|
||||||
|
var matchedTool *api.Tool
|
||||||
|
for i := range tools {
|
||||||
|
if tools[i].Function.Name == functionCall.Name {
|
||||||
|
matchedTool = &tools[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||||
|
for _, parameter := range functionCall.Parameters {
|
||||||
|
// Look up the parameter type if we found the tool
|
||||||
|
var paramType api.PropertyType
|
||||||
|
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||||
|
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
|
||||||
|
paramType = prop.Type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCall, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
|
||||||
|
//
|
||||||
|
// For union types (multiple types in PropertyType, which we support but doesn't
|
||||||
|
// seem as though the reference parser does type coercion with those types in
|
||||||
|
// mind) we use a type precedence approach:
|
||||||
|
// 1. null - checked first regardless of declared types (matches reference implementation)
|
||||||
|
// 2. boolean - only "true"/"false" are valid booleans
|
||||||
|
// 3. integer - must parse as a whole number
|
||||||
|
// 4. number - must parse as numeric (returns int if no decimal part)
|
||||||
|
// 5. array - must parse as valid JSON array
|
||||||
|
// 6. object - must parse as valid JSON object
|
||||||
|
// 7. string - always succeeds (least specific type)
|
||||||
|
//
|
||||||
|
// This precedence ensures we return the most specific type that successfully parses,
|
||||||
|
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
|
||||||
|
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
|
||||||
|
func parseValue(raw string, paramType api.PropertyType) any {
|
||||||
|
// first remove a single leading newlines, and a single trailing newline (if
|
||||||
|
// they exist). This follows the reference implementation
|
||||||
|
raw = strings.TrimPrefix(raw, "\n")
|
||||||
|
raw = strings.TrimSuffix(raw, "\n")
|
||||||
|
|
||||||
|
// Check for null first (case-insensitive) - this takes precedence over any type
|
||||||
|
if strings.ToLower(raw) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no type is specified, default to string
|
||||||
|
if len(paramType) == 0 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if any of the specified types match, using type precedence
|
||||||
|
// Order: boolean -> integer -> number -> array -> object -> string
|
||||||
|
typeSet := make(map[string]bool)
|
||||||
|
for _, t := range paramType {
|
||||||
|
typeSet[t] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try boolean first (most restrictive)
|
||||||
|
if typeSet["boolean"] {
|
||||||
|
lower := strings.ToLower(raw)
|
||||||
|
switch lower {
|
||||||
|
case "true":
|
||||||
|
return true
|
||||||
|
case "false":
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// If not a valid boolean but boolean is the only type, return false (matching reference)
|
||||||
|
if len(paramType) == 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Otherwise try other types
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try integer
|
||||||
|
if typeSet["integer"] {
|
||||||
|
if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||||
|
// Return as int if it fits in int32, otherwise int64
|
||||||
|
if i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||||
|
return int(i)
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
// If integer is the only type and parsing failed, fall back to string
|
||||||
|
if len(paramType) == 1 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try number (float)
|
||||||
|
if typeSet["number"] {
|
||||||
|
if f, err := strconv.ParseFloat(raw, 64); err == nil {
|
||||||
|
// If the number has no decimal part, return as int (matching reference)
|
||||||
|
if f == math.Trunc(f) {
|
||||||
|
i := int64(f)
|
||||||
|
if i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||||
|
return int(i)
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
// If number is the only type and parsing failed, fall back to string
|
||||||
|
if len(paramType) == 1 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try array
|
||||||
|
if typeSet["array"] {
|
||||||
|
var arr []interface{}
|
||||||
|
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
|
||||||
|
return arr
|
||||||
|
}
|
||||||
|
// If array is the only type and parsing failed, fall back to string
|
||||||
|
if len(paramType) == 1 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try object
|
||||||
|
if typeSet["object"] {
|
||||||
|
var obj map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
// If object is the only type and parsing failed, fall back to string
|
||||||
|
if len(paramType) == 1 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String always succeeds (or if "string" is in the type set)
|
||||||
|
if typeSet["string"] {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get here, none of the types matched and string wasn't an option
|
||||||
|
// We return string as a fallback. The reference implementation will attempt
|
||||||
|
// to parse the value as a python literal, but we purposefully don't support
|
||||||
|
// that
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||||
|
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
||||||
|
// xml so that it can be parsed by any xml parser
|
||||||
|
func transformToXML(raw string) string {
|
||||||
|
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
||||||
|
// care to properly escape the string that becomes the attribute value
|
||||||
|
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||||
|
groups := qwenTagRegex.FindStringSubmatch(match)
|
||||||
|
tag := groups[1]
|
||||||
|
var escapedValue strings.Builder
|
||||||
|
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
||||||
|
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Walk the resulting string, escaping any character data that sits between the
|
||||||
|
// xml tags we just emitted
|
||||||
|
var out strings.Builder
|
||||||
|
lastIdx := 0
|
||||||
|
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
|
||||||
|
if loc[0] > lastIdx {
|
||||||
|
escapeTextNode(&out, transformed[lastIdx:loc[0]])
|
||||||
|
}
|
||||||
|
out.WriteString(transformed[loc[0]:loc[1]])
|
||||||
|
lastIdx = loc[1]
|
||||||
|
}
|
||||||
|
if lastIdx < len(transformed) {
|
||||||
|
escapeTextNode(&out, transformed[lastIdx:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeTextNode escapes XML character data without altering other characters
|
||||||
|
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
|
||||||
|
func escapeTextNode(sb *strings.Builder, s string) {
|
||||||
|
for _, r := range s {
|
||||||
|
switch r {
|
||||||
|
case '&':
|
||||||
|
sb.WriteString("&")
|
||||||
|
case '<':
|
||||||
|
sb.WriteString("<")
|
||||||
|
case '>':
|
||||||
|
sb.WriteString(">")
|
||||||
|
default:
|
||||||
|
sb.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
model/parsers/qwen3coder_test.go (new file, 878 lines)
@@ -0,0 +1,878 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tool creates a test tool with the given name and properties
|
||||||
|
func tool(name string, props map[string]api.ToolProperty) api.Tool {
|
||||||
|
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
|
||||||
|
t.Function.Parameters.Type = "object"
|
||||||
|
t.Function.Parameters.Properties = props
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwenParserStreaming(t *testing.T) {
|
||||||
|
type step struct {
|
||||||
|
input string
|
||||||
|
wantEvents []qwenEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
steps []step
|
||||||
|
only bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "simple message streamed word by word",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "hi",
|
||||||
|
wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: " there",
|
||||||
|
wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "content before tool call",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "hi there<tool_call>",
|
||||||
|
wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "multiple tool calls in one message",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "before1<tool_call>in tool call</tool_call>after1<tool_call>in tool call 2</tool_call>after2",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "before1"},
|
||||||
|
qwenEventRawToolCall{raw: "in tool call"},
|
||||||
|
qwenEventContent{content: "after1"},
|
||||||
|
qwenEventRawToolCall{raw: "in tool call 2"},
|
||||||
|
qwenEventContent{content: "after2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "tool calls with split tags",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "before<tool",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "before"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "_call>in tool call</tool",
|
||||||
|
wantEvents: []qwenEvent{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "_call>af",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventRawToolCall{raw: "in tool call"},
|
||||||
|
qwenEventContent{content: "af"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "ter",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "ter"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "trailing whitespace between content and tool call",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "abc\n<tool_call>def</tool_call>",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "abc"},
|
||||||
|
qwenEventRawToolCall{raw: "def"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "trailing whitespace between tool call and content",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<tool_call>abc</tool_call>\ndef",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventRawToolCall{raw: "abc"},
|
||||||
|
qwenEventContent{content: "def"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "empty content before tool call",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "\n<tool_call>abc</tool_call>",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventRawToolCall{raw: "abc"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "partial tool open tag fakeout",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "abc\n<tool_call",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
// \n should not be emitted yet because `<tool_call` might be a tool
|
||||||
|
// open tag, in which case the whitespace should be trimmed
|
||||||
|
qwenEventContent{content: "abc"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: " fakeout",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "\n<tool_call fakeout"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "token-by-token whitespace handling",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "a",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "a"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "\n",
|
||||||
|
wantEvents: []qwenEvent{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "b",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "\nb"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
anyOnlies := false
|
||||||
|
for _, tc := range cases {
|
||||||
|
if tc.only {
|
||||||
|
anyOnlies = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
if anyOnlies && !tc.only {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
|
parser := Qwen3CoderParser{}
|
||||||
|
|
||||||
|
for i, step := range tc.steps {
|
||||||
|
parser.acc.WriteString(step.input)
|
||||||
|
gotEvents := parser.parseEvents()
|
||||||
|
|
||||||
|
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
|
||||||
|
// avoid deep equal on empty vs. nil slices
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||||
|
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwenToolParser(t *testing.T) {
|
||||||
|
type step struct {
|
||||||
|
name string
|
||||||
|
rawToolCall string
|
||||||
|
tools []api.Tool
|
||||||
|
wantToolCall api.ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
steps := []step{
|
||||||
|
{
|
||||||
|
name: "simple tool call",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=get_current_temperature>
|
||||||
|
<parameter=location>
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
celsius
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_temperature",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "San Francisco",
|
||||||
|
"unit": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "names with spaces",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=get current temperature>
|
||||||
|
<parameter=location with spaces>
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit with spaces>
|
||||||
|
celsius
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get current temperature",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location with spaces": "San Francisco",
|
||||||
|
"unit with spaces": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// this mirrors the reference implementation's behavior, but unclear if it
|
||||||
|
// ever happens. If so, then we should probably remove them instead, this
|
||||||
|
// test is to just document the current behavior and test that we don't get
|
||||||
|
// xml errors
|
||||||
|
{
|
||||||
|
name: "names with quotes",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function="get current temperature">
|
||||||
|
<parameter="location with spaces">
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter="unit with spaces">
|
||||||
|
"celsius"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "\"get current temperature\"",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"\"location with spaces\"": "San Francisco",
|
||||||
|
"\"unit with spaces\"": "\"celsius\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with typed parameters",
|
||||||
|
tools: []api.Tool{
|
||||||
|
tool("calculate", map[string]api.ToolProperty{
|
||||||
|
"x": {Type: api.PropertyType{"number"}},
|
||||||
|
"y": {Type: api.PropertyType{"integer"}},
|
||||||
|
"enabled": {Type: api.PropertyType{"boolean"}},
|
||||||
|
"items": {Type: api.PropertyType{"array"}},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
rawToolCall: `<function=calculate>
|
||||||
|
<parameter=x>
|
||||||
|
3.14
|
||||||
|
</parameter>
|
||||||
|
<parameter=y>
|
||||||
|
42
|
||||||
|
</parameter>
|
||||||
|
<parameter=enabled>
|
||||||
|
true
|
||||||
|
</parameter>
|
||||||
|
<parameter=items>
|
||||||
|
["a", "b", "c"]
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "calculate",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"x": 3.14,
|
||||||
|
"y": 42,
|
||||||
|
"enabled": true,
|
||||||
|
"items": []any{"a", "b", "c"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// regression test for <https://github.com/ollama/ollama/issues/12357>
|
||||||
|
{
|
||||||
|
name: "ampersands in parameter values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=exec>
|
||||||
|
<parameter=command>
|
||||||
|
ls && echo "done"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "exec",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"command": "ls && echo \"done\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "angle brackets in parameter values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=exec>
|
||||||
|
<parameter=command>
|
||||||
|
ls && echo "a > b and a < b"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "exec",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"command": "ls && echo \"a > b and a < b\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, step := range steps {
|
||||||
|
gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||||
|
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwenToolCallValueParsing(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
raw string
|
||||||
|
paramType api.PropertyType
|
||||||
|
want any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "default string value (no type specified)",
|
||||||
|
paramType: api.PropertyType{},
|
||||||
|
raw: "some-string",
|
||||||
|
want: "some-string",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "trim a single leading and trailing newline",
|
||||||
|
paramType: api.PropertyType{},
|
||||||
|
raw: "\nsome-string\n",
|
||||||
|
want: "some-string",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "trim at most one leading and trailing newline",
|
||||||
|
paramType: api.PropertyType{},
|
||||||
|
raw: "\n\nsome-string\n\n",
|
||||||
|
want: "\nsome-string\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "newline really has to be the first character to be trimmed",
|
||||||
|
paramType: api.PropertyType{},
|
||||||
|
raw: " \nsome-string\n ",
|
||||||
|
want: " \nsome-string\n ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "numeric type",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "123",
|
||||||
|
want: 123,
|
||||||
|
},
|
||||||
|
// Integer parsing tests
|
||||||
|
{
|
||||||
|
desc: "integer type",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "42",
|
||||||
|
want: 42,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "negative integer",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "-100",
|
||||||
|
want: -100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "zero integer",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "0",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "integer with leading zeros",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "007",
|
||||||
|
want: 7,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "large integer",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "2147483648", // Just beyond int32 max
|
||||||
|
want: int64(2147483648),
|
||||||
|
},
|
||||||
|
// Float/number parsing tests
|
||||||
|
{
|
||||||
|
desc: "float type",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "3.14",
|
||||||
|
want: 3.14,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "negative float",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "-273.15",
|
||||||
|
want: -273.15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "float without decimal part",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "100.0",
|
||||||
|
want: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "scientific notation positive",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "1.23e5",
|
||||||
|
want: 123000, // Will be int since it has no decimal part
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "scientific notation negative",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "1.5e-3",
|
||||||
|
want: 0.0015,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "very small float",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "0.00000001",
|
||||||
|
want: 0.00000001,
|
||||||
|
},
|
||||||
|
// String parsing tests
|
||||||
|
{
|
||||||
|
desc: "explicit string type",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: "hello world",
|
||||||
|
want: "hello world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string with special characters",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: "/usr/local/bin/test-file_v2.0.sh",
|
||||||
|
want: "/usr/local/bin/test-file_v2.0.sh",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string with quotes",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: `He said "hello" to me`,
|
||||||
|
want: `He said "hello" to me`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "multiline string",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: "line one\nline two\nline three",
|
||||||
|
want: "line one\nline two\nline three",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "empty string",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: "",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string that looks like a number",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: "12345",
|
||||||
|
want: "12345",
|
||||||
|
},
|
||||||
|
// Boolean parsing tests
|
||||||
|
{
|
||||||
|
desc: "boolean true",
|
||||||
|
paramType: api.PropertyType{"boolean"},
|
||||||
|
raw: "true",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "boolean false",
|
||||||
|
paramType: api.PropertyType{"boolean"},
|
||||||
|
raw: "false",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "boolean case insensitive true",
|
||||||
|
paramType: api.PropertyType{"boolean"},
|
||||||
|
raw: "True",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "boolean case insensitive false",
|
||||||
|
paramType: api.PropertyType{"boolean"},
|
||||||
|
raw: "FALSE",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
// Null parsing tests
|
||||||
|
{
|
||||||
|
desc: "null value lowercase",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: "null",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "null value case insensitive",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "NULL",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
// Array parsing tests
|
||||||
|
{
|
||||||
|
desc: "array of strings",
|
||||||
|
paramType: api.PropertyType{"array"},
|
||||||
|
raw: `["foo", "bar", "baz"]`,
|
||||||
|
want: []any{"foo", "bar", "baz"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "array of numbers",
|
||||||
|
paramType: api.PropertyType{"array"},
|
||||||
|
raw: `[1, 2.5, 3]`,
|
||||||
|
want: []any{float64(1), 2.5, float64(3)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "array of mixed types",
|
||||||
|
paramType: api.PropertyType{"array"},
|
||||||
|
raw: `["string", 123, true, null]`,
|
||||||
|
want: []any{"string", float64(123), true, nil},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "empty array",
|
||||||
|
paramType: api.PropertyType{"array"},
|
||||||
|
raw: `[]`,
|
||||||
|
want: []any{},
|
||||||
|
},
|
||||||
|
// Object parsing tests
|
||||||
|
{
|
||||||
|
desc: "simple object",
|
||||||
|
paramType: api.PropertyType{"object"},
|
||||||
|
raw: `{"key": "value", "number": 42}`,
|
||||||
|
want: map[string]any{"key": "value", "number": float64(42)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "nested object",
|
||||||
|
paramType: api.PropertyType{"object"},
|
||||||
|
raw: `{"outer": {"inner": "value"}}`,
|
||||||
|
want: map[string]any{"outer": map[string]any{"inner": "value"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "empty object",
|
||||||
|
paramType: api.PropertyType{"object"},
|
||||||
|
raw: `{}`,
|
||||||
|
want: map[string]any{},
|
||||||
|
},
|
||||||
|
// Error cases and fallback behavior
|
||||||
|
{
|
||||||
|
desc: "invalid integer falls back to string",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "not-a-number",
|
||||||
|
want: "not-a-number",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "invalid float falls back to string",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "3.14.159",
|
||||||
|
want: "3.14.159",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "invalid boolean falls back to false",
|
||||||
|
paramType: api.PropertyType{"boolean"},
|
||||||
|
raw: "yes",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "invalid JSON array falls back to string",
|
||||||
|
paramType: api.PropertyType{"array"},
|
||||||
|
raw: "[1, 2, unclosed",
|
||||||
|
want: "[1, 2, unclosed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "invalid JSON object falls back to string",
|
||||||
|
paramType: api.PropertyType{"object"},
|
||||||
|
raw: `{"key": unclosed`,
|
||||||
|
want: `{"key": unclosed`,
|
||||||
|
},
|
||||||
|
// Edge cases
|
||||||
|
{
|
||||||
|
desc: "integer overflow should use int64",
|
||||||
|
paramType: api.PropertyType{"integer"},
|
||||||
|
raw: "2147483648", // Beyond int32 max
|
||||||
|
want: int64(2147483648),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "float with many decimal places",
|
||||||
|
paramType: api.PropertyType{"number"},
|
||||||
|
raw: "3.141592653589793",
|
||||||
|
want: 3.141592653589793,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string with JSON-like content",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: `{"this": "is", "just": "a string"}`,
|
||||||
|
want: `{"this": "is", "just": "a string"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "whitespace-only string",
|
||||||
|
paramType: api.PropertyType{"string"},
|
||||||
|
raw: " ",
|
||||||
|
want: " ",
|
||||||
|
},
|
||||||
|
// Unknown parameter (no type specified in tools)
|
||||||
|
{
|
||||||
|
desc: "parameter not in tool definition defaults to string",
|
||||||
|
paramType: api.PropertyType{},
|
||||||
|
raw: "some value",
|
||||||
|
want: "some value",
|
||||||
|
},
|
||||||
|
// Union type tests
|
||||||
|
{
|
||||||
|
desc: "string or number union - valid number",
|
||||||
|
paramType: api.PropertyType{"string", "number"},
|
||||||
|
raw: "42.5",
|
||||||
|
want: 42.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string or number union - non-numeric string",
|
||||||
|
paramType: api.PropertyType{"string", "number"},
|
||||||
|
raw: "hello",
|
||||||
|
want: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "number or string union - valid number (order shouldn't matter)",
|
||||||
|
paramType: api.PropertyType{"number", "string"},
|
||||||
|
raw: "42.5",
|
||||||
|
want: 42.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "integer or null union - valid integer",
|
||||||
|
paramType: api.PropertyType{"integer", "null"},
|
||||||
|
raw: "123",
|
||||||
|
want: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "integer or null union - null value",
|
||||||
|
paramType: api.PropertyType{"integer", "null"},
|
||||||
|
raw: "null",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "null or integer union - null value (order shouldn't matter)",
|
||||||
|
paramType: api.PropertyType{"null", "integer"},
|
||||||
|
raw: "null",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "boolean or string union - valid boolean",
|
||||||
|
paramType: api.PropertyType{"boolean", "string"},
|
||||||
|
raw: "true",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "boolean or string union - non-boolean becomes string",
|
||||||
|
paramType: api.PropertyType{"boolean", "string"},
|
||||||
|
raw: "yes",
|
||||||
|
want: "yes",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string or boolean union - valid boolean (precedence test)",
|
||||||
|
paramType: api.PropertyType{"string", "boolean"},
|
||||||
|
raw: "false",
|
||||||
|
want: false, // Should be boolean, not string "false"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "integer or number union - integer value",
|
||||||
|
paramType: api.PropertyType{"integer", "number"},
|
||||||
|
raw: "42",
|
||||||
|
want: 42,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "integer or number union - float value",
|
||||||
|
paramType: api.PropertyType{"integer", "number"},
|
||||||
|
raw: "42.5",
|
||||||
|
want: 42.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "number or integer union - integer value (precedence test)",
|
||||||
|
paramType: api.PropertyType{"number", "integer"},
|
||||||
|
raw: "42",
|
||||||
|
want: 42, // Should try integer first due to precedence
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "array or object union - valid array",
|
||||||
|
paramType: api.PropertyType{"array", "object"},
|
||||||
|
raw: `[1, 2, 3]`,
|
||||||
|
want: []any{float64(1), float64(2), float64(3)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "array or object union - valid object",
|
||||||
|
paramType: api.PropertyType{"array", "object"},
|
||||||
|
raw: `{"key": "value"}`,
|
||||||
|
want: map[string]any{"key": "value"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "object or array union - valid array (precedence test)",
|
||||||
|
paramType: api.PropertyType{"object", "array"},
|
||||||
|
raw: `[1, 2, 3]`,
|
||||||
|
want: []any{float64(1), float64(2), float64(3)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "complex multi-type union - null",
|
||||||
|
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||||
|
raw: "null",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "complex multi-type union - boolean",
|
||||||
|
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||||
|
raw: "true",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "complex multi-type union - number",
|
||||||
|
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||||
|
raw: "3.14",
|
||||||
|
want: 3.14,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "complex multi-type union - string",
|
||||||
|
paramType: api.PropertyType{"string", "number", "boolean", "null"},
|
||||||
|
raw: "hello",
|
||||||
|
want: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "integer string union - integer string becomes integer",
|
||||||
|
paramType: api.PropertyType{"integer", "string"},
|
||||||
|
raw: "123",
|
||||||
|
want: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "string integer union - integer string becomes integer (precedence)",
|
||||||
|
paramType: api.PropertyType{"string", "integer"},
|
||||||
|
raw: "123",
|
||||||
|
want: 123, // Integer has higher precedence than string
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
|
got := parseValue(tc.raw, tc.paramType)
|
||||||
|
if !reflect.DeepEqual(got, tc.want) {
|
||||||
|
t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
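Two of the union-type rules the table above encodes, shown as a sketch (parseValue is the unexported helper exercised by this test):

	parseValue("123", api.PropertyType{"string", "integer"})  // -> 123: integer is tried before string
	parseValue("hello", api.PropertyType{"string", "number"}) // -> "hello": falls back to string when number parsing fails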
|
||||||
|
|
||||||
|
func TestQwenXMLTransform(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
raw string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "simple example",
|
||||||
|
raw: `<function=get_current_temperature>
|
||||||
|
<parameter=location>
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
celsius
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
want: `<function name="get_current_temperature">
|
||||||
|
<parameter name="location">
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter name="unit">
|
||||||
|
celsius
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
},
|
||||||
|
// even though quotes aren't expected in these tags, we have these tests to
|
||||||
|
// make sure they're escaped so they don't blow up the xml parser in case
|
||||||
|
// they happen
|
||||||
|
{
|
||||||
|
desc: "names with quotes",
|
||||||
|
raw: `<function="get current temperature">
|
||||||
|
<parameter="location with spaces">
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter="unit with spaces">
|
||||||
|
celsius
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
want: `<function name=""get current temperature"">
|
||||||
|
<parameter name=""location with spaces"">
|
||||||
|
San Francisco
|
||||||
|
</parameter>
|
||||||
|
<parameter name=""unit with spaces"">
|
||||||
|
celsius
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ampersands in parameter values",
|
||||||
|
raw: `<function=get_current_temperature>
|
||||||
|
<parameter=location>
|
||||||
|
San Francisco & San Jose
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
want: `<function name="get_current_temperature">
|
||||||
|
<parameter name="location">
|
||||||
|
San Francisco &amp; San Jose
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
got := transformToXML(tc.raw)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrailingWhitespaceLen(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
s string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{desc: "no whitespace", s: "abc", want: 0},
|
||||||
|
{desc: "trailing whitespace", s: "abc ", want: 1},
|
||||||
|
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
|
||||||
|
{desc: "only whitespace", s: " \n ", want: 4},
|
||||||
|
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
got := trailingWhitespaceLen(tc.s)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("got %d, want %d", got, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
217 model/renderers/qwen3coder.go Normal file
@@ -0,0 +1,217 @@
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
imStartTag = "<|im_start|>"
|
||||||
|
imEndTag = "<|im_end|>"
|
||||||
|
)
|
||||||
|
|
||||||
|
// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
|
||||||
|
// This follows the same approach from the reference implementation, which gives
|
||||||
|
// a particular key ordering
|
||||||
|
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
|
||||||
|
data, err := json.Marshal(obj)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(data, &m); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
for key, value := range m {
|
||||||
|
if handledKeys[key] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if value is a map or array (needs JSON serialization)
|
||||||
|
switch v := value.(type) {
|
||||||
|
case map[string]any, []any:
|
||||||
|
jsonBytes, _ := json.Marshal(v)
|
||||||
|
// TODO(drifkin): it would be nice to format the JSON here similarly to
|
||||||
|
// python's default json.dumps behavior (spaces after commas and colons).
|
||||||
|
// This would let us be byte-for-byte compatible with the reference
|
||||||
|
// implementation for most common inputs
|
||||||
|
jsonStr := string(jsonBytes)
|
||||||
|
sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
|
||||||
|
case nil:
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
// Simple types, convert to string
|
||||||
|
sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
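As a sketch of what this pass emits, consider a property whose only extra key is an enum (the expected string matches the <enum> line asserted in TestQwen3CoderRenderer below):

	prop := api.ToolProperty{
		Type:        api.PropertyType{"string"},
		Description: "The unit of temperature",
		Enum:        []any{"celsius", "fahrenheit"},
	}
	extra := renderAdditionalKeys(prop, map[string]bool{"type": true, "description": true})
	// extra == "\n<enum>[\"celsius\",\"fahrenheit\"]</enum>"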
|
||||||
|
|
||||||
|
func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// filter out system messages and choose the first (if any) to win
|
||||||
|
var systemMessage string
|
||||||
|
var filteredMessages []api.Message
|
||||||
|
for _, message := range messages {
|
||||||
|
if message.Role != "system" {
|
||||||
|
filteredMessages = append(filteredMessages, message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemMessage == "" {
|
||||||
|
systemMessage = message.Content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemMessage != "" || len(tools) > 0 {
|
||||||
|
sb.WriteString(imStartTag + "system\n")
|
||||||
|
|
||||||
|
// if we have tools but no system message, match the reference implementation by providing a default system message
|
||||||
|
if systemMessage == "" {
|
||||||
|
systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(systemMessage)
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
|
||||||
|
sb.WriteString("<tools>")
|
||||||
|
for _, tool := range tools {
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString("<function>\n")
|
||||||
|
sb.WriteString("<name>" + tool.Function.Name + "</name>")
|
||||||
|
if tool.Function.Description != "" {
|
||||||
|
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
|
||||||
|
}
|
||||||
|
sb.WriteString("\n<parameters>")
|
||||||
|
|
||||||
|
for name, prop := range tool.Function.Parameters.Properties {
|
||||||
|
sb.WriteString("\n<parameter>")
|
||||||
|
sb.WriteString("\n<name>" + name + "</name>")
|
||||||
|
|
||||||
|
if len(prop.Type) > 0 {
|
||||||
|
// TODO(!!!)(drifkin): we should match the reference implementation for
|
||||||
|
// more complex types here instead of using this format
|
||||||
|
sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
|
||||||
|
}
|
||||||
|
|
||||||
|
if prop.Description != "" {
|
||||||
|
sb.WriteString("\n<description>" + prop.Description + "</description>")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render any additional keys not already handled
|
||||||
|
handledKeys := map[string]bool{
|
||||||
|
"type": true,
|
||||||
|
"description": true,
|
||||||
|
}
|
||||||
|
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
|
||||||
|
|
||||||
|
sb.WriteString("\n</parameter>")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render extra keys for parameters (everything except 'type' and 'properties')
|
||||||
|
paramHandledKeys := map[string]bool{
|
||||||
|
"type": true,
|
||||||
|
"properties": true,
|
||||||
|
}
|
||||||
|
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
|
||||||
|
|
||||||
|
sb.WriteString("\n</parameters>")
|
||||||
|
sb.WriteString("\n</function>")
|
||||||
|
}
|
||||||
|
sb.WriteString("\n</tools>")
|
||||||
|
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(imEndTag + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, message := range filteredMessages {
|
||||||
|
lastMessage := i == len(filteredMessages)-1
|
||||||
|
prefill := lastMessage && message.Role == "assistant"
|
||||||
|
switch message.Role {
|
||||||
|
case "assistant":
|
||||||
|
if len(message.ToolCalls) > 0 {
|
||||||
|
sb.WriteString(imStartTag + "assistant\n")
|
||||||
|
if message.Content != "" {
|
||||||
|
sb.WriteString(message.Content + "\n")
|
||||||
|
}
|
||||||
|
for _, toolCall := range message.ToolCalls {
|
||||||
|
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
||||||
|
for name, value := range toolCall.Function.Arguments {
|
||||||
|
valueStr := formatToolCallArgument(value)
|
||||||
|
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
||||||
|
}
|
||||||
|
sb.WriteString("\n</function>\n</tool_call>")
|
||||||
|
}
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
} else {
|
||||||
|
sb.WriteString(imStartTag + "assistant\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
if !prefill {
|
||||||
|
sb.WriteString(imEndTag + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
// consecutive tool responses should share a single `<im_start>user`, but
|
||||||
|
// have their own <tool_response> tags
|
||||||
|
|
||||||
|
// only start a new user block if this is the first tool response
|
||||||
|
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
||||||
|
sb.WriteString(imStartTag + "user\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString("<tool_response>\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
sb.WriteString("\n</tool_response>\n")
|
||||||
|
|
||||||
|
// close the user block only if this is the last tool response
|
||||||
|
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
||||||
|
sb.WriteString(imEndTag + "\n")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
sb.WriteString(imStartTag + message.Role + "\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
sb.WriteString(imEndTag + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastMessage && !prefill {
|
||||||
|
sb.WriteString(imStartTag + "assistant\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatToolCallArgument(value any) string {
|
||||||
|
if value == nil {
|
||||||
|
return "null"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
case []byte:
|
||||||
|
return string(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.TypeOf(value) != nil {
|
||||||
|
kind := reflect.TypeOf(value).Kind()
|
||||||
|
if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
|
||||||
|
if marshalled, err := json.Marshal(value); err == nil {
|
||||||
|
return string(marshalled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
338 model/renderers/qwen3coder_test.go Normal file
@@ -0,0 +1,338 @@
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQwen3CoderRenderer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msgs []api.Message
|
||||||
|
tools []api.Tool
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Hello, how are you?"},
|
||||||
|
},
|
||||||
|
expected: `<|im_start|>system
|
||||||
|
You are a helpful assistant.<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
Hello, how are you?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with tools and response",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||||
|
{Role: "user", Content: "What is the weather like in San Francisco?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "I'll check the weather in San Francisco for you.",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
|
||||||
|
{Role: "user", Content: "That sounds nice! What about New York?"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather in a given location",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Required: []string{"unit"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
||||||
|
// TODO(drifkin): add multiple params back once we have predictable
|
||||||
|
// order via some sort of ordered map type (see
|
||||||
|
// <https://github.com/ollama/ollama/issues/12244>)
|
||||||
|
/*
|
||||||
|
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
||||||
|
*/
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
expected: `<|im_start|>system
|
||||||
|
You are a helpful assistant with access to tools.
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
<function>
|
||||||
|
<name>get_weather</name>
|
||||||
|
<description>Get the current weather in a given location</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>unit</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>The unit of temperature</description>
|
||||||
|
<enum>["celsius","fahrenheit"]</enum>
|
||||||
|
</parameter>
|
||||||
|
<required>["unit"]</required>
|
||||||
|
</parameters>
|
||||||
|
</function>
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=example_function_name>
|
||||||
|
<parameter=example_parameter_1>
|
||||||
|
value_1
|
||||||
|
</parameter>
|
||||||
|
<parameter=example_parameter_2>
|
||||||
|
This is the value for the second parameter
|
||||||
|
that can span
|
||||||
|
multiple lines
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
<IMPORTANT>
|
||||||
|
Reminder:
|
||||||
|
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||||
|
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||||
|
</IMPORTANT><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
What is the weather like in San Francisco?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
I'll check the weather in San Francisco for you.
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=get_weather>
|
||||||
|
<parameter=unit>
|
||||||
|
fahrenheit
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
|
||||||
|
</tool_response>
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
That sounds nice! What about New York?<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parallel tool calls",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||||
|
{Role: "user", Content: "call double(1) and triple(2)"},
|
||||||
|
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
|
||||||
|
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
|
||||||
|
}},
|
||||||
|
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
||||||
|
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||||
|
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
||||||
|
}}}},
|
||||||
|
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||||
|
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
||||||
|
}}}},
|
||||||
|
},
|
||||||
|
expected: `<|im_start|>system
|
||||||
|
You are a helpful assistant with access to tools.
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
<function>
|
||||||
|
<name>double</name>
|
||||||
|
<description>Double a number</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>number</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>The number to double</description>
|
||||||
|
</parameter>
|
||||||
|
</parameters>
|
||||||
|
</function>
|
||||||
|
<function>
|
||||||
|
<name>triple</name>
|
||||||
|
<description>Triple a number</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>number</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>The number to triple</description>
|
||||||
|
</parameter>
|
||||||
|
</parameters>
|
||||||
|
</function>
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=example_function_name>
|
||||||
|
<parameter=example_parameter_1>
|
||||||
|
value_1
|
||||||
|
</parameter>
|
||||||
|
<parameter=example_parameter_2>
|
||||||
|
This is the value for the second parameter
|
||||||
|
that can span
|
||||||
|
multiple lines
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
<IMPORTANT>
|
||||||
|
Reminder:
|
||||||
|
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||||
|
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||||
|
</IMPORTANT><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
call double(1) and triple(2)<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
I'll call double(1) and triple(2) for you.
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=double>
|
||||||
|
<parameter=number>
|
||||||
|
1
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
<function=triple>
|
||||||
|
<parameter=number>
|
||||||
|
2
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
{"number": 2}
|
||||||
|
</tool_response>
|
||||||
|
<tool_response>
|
||||||
|
{"number": 6}
|
||||||
|
</tool_response>
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "prefill",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Tell me something interesting."},
|
||||||
|
{Role: "assistant", Content: "I'll tell you something interesting about cats"},
|
||||||
|
},
|
||||||
|
expected: `<|im_start|>system
|
||||||
|
You are a helpful assistant.<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
Tell me something interesting.<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
I'll tell you something interesting about cats`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex tool call arguments should remain json encoded",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "call tool"},
|
||||||
|
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||||
|
{Function: api.ToolCallFunction{
|
||||||
|
Name: "echo",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"payload": map[string]any{"foo": "bar"},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
||||||
|
},
|
||||||
|
expected: `<|im_start|>user
|
||||||
|
call tool<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
<function=echo>
|
||||||
|
<parameter=payload>
|
||||||
|
{"foo":"bar"}
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
{"payload": {"foo": "bar"}}
|
||||||
|
</tool_response>
|
||||||
|
<|im_end|>
|
||||||
|
<|im_start|>assistant
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatToolCallArgument(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
arg any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string",
|
||||||
|
arg: "foo",
|
||||||
|
// notice no quotes around the string
|
||||||
|
expected: "foo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "map",
|
||||||
|
arg: map[string]any{"foo": "bar"},
|
||||||
|
expected: "{\"foo\":\"bar\"}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number",
|
||||||
|
arg: 1,
|
||||||
|
expected: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean",
|
||||||
|
arg: true,
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := formatToolCallArgument(tt.arg)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
26 model/renderers/renderer.go Normal file
@@ -0,0 +1,26 @@
package renderers

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)

func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	renderer := rendererForName(name)
	if renderer == nil {
		return "", fmt.Errorf("unknown renderer %q", name)
	}
	return renderer(msgs, tools, think)
}

func rendererForName(name string) rendererFunc {
	switch name {
	case "qwen3-coder":
		return Qwen3CoderRenderer
	default:
		return nil
	}
}
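A minimal usage sketch of the entry point above (the message content is made up for illustration; nil is passed for tools and think):

	prompt, err := RenderWithRenderer("qwen3-coder", []api.Message{
		{Role: "user", Content: "Hello"},
	}, nil, nil)
	if err != nil {
		// only an unrecognized renderer name reaches this branch
	}
	_ = prompt // the rendered <|im_start|>...<|im_end|> chat prompt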
@@ -12,18 +12,18 @@ import (
|
|||||||
|
|
||||||
const spmWhitespaceSep = "▁"
|
const spmWhitespaceSep = "▁"
|
||||||
|
|
||||||
type SentencePieceModel struct {
|
type SentencePiece struct {
|
||||||
maxTokenLen int
|
maxTokenLen int
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
var _ TextProcessor = (*SentencePiece)(nil)
|
||||||
|
|
||||||
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||||
return spm.vocab
|
return spm.vocab
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||||
|
|
||||||
counter := map[int]int{}
|
counter := map[int]int{}
|
||||||
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
|||||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||||
"max token len", maxTokenLen)
|
"max token len", maxTokenLen)
|
||||||
|
|
||||||
return SentencePieceModel{
|
return SentencePiece{
|
||||||
maxTokenLen: maxTokenLen,
|
maxTokenLen: maxTokenLen,
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||||
return spm.vocab.Is(id, special)
|
return spm.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
fragments := []fragment{{value: s}}
|
fragments := []fragment{{value: s}}
|
||||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||||
id := spm.vocab.Encode(special)
|
id := spm.vocab.Encode(special)
|
||||||
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
|
|||||||
return item
|
return item
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
data := spm.vocab.Decode(id)
|
data := spm.vocab.Decode(id)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/ollama/ollama/convert/sentencepiece"
|
"github.com/ollama/ollama/convert/sentencepiece"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||||
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewSentencePieceModel(&v)
|
return NewSentencePiece(&v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceEncode(t *testing.T) {
|
func TestSentencePieceEncode(t *testing.T) {
|
||||||
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||||
vocab := &Vocabulary{
|
vocab := &Vocabulary{
|
||||||
Values: []string{
|
Values: []string{
|
||||||
"normal",
|
"normal",
|
||||||
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
|||||||
Scores: []float32{0, 0, 0, 0, 0},
|
Scores: []float32{0, 0, 0, 0, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
spm := NewSentencePieceModel(vocab)
|
spm := NewSentencePiece(vocab)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
167 model/wordpiece.go Normal file
@@ -0,0 +1,167 @@
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
	"github.com/ollama/ollama/logutil"
)

type WordPiece struct {
	vocab *Vocabulary
}

// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"

var wordPieceReplacer = strings.NewReplacer(
	" .", ".",
	" ?", "?",
	" !", "!",
	" ,", ",",
	" ' ", "'",
	" n't", "n't",
	" 'm", "'m",
	" do not", " don't",
	" 's", "'s",
	" 've", "'ve",
	" 're", "'re",
)

// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
	var sb strings.Builder
	for i, id := range ids {
		if id < 0 || int(id) >= len(wpm.vocab.Values) {
			return "", fmt.Errorf("invalid token id: %d", id)
		}

		var separator string
		piece := wpm.vocab.Values[id]
		if i > 0 &&
			(strings.HasPrefix(piece, ggmlPrefix) ||
				(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
			separator = " "
		}

		sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
	}

	return sb.String(), nil
}

// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
	return func(yield func(string) bool) {
		runes := make([]rune, 0, len(s)*3)
		for _, r := range s {
			switch {
			case r >= 0x4E00 && r <= 0x9FFF,
				r >= 0x3400 && r <= 0x4DBF,
				r >= 0x20000 && r <= 0x2A6DF,
				r >= 0x2A700 && r <= 0x2B73F,
				r >= 0x2B740 && r <= 0x2B81F,
				r >= 0x2B820 && r <= 0x2CEAF,
				r >= 0xF900 && r <= 0xFAFF,
				r >= 0x2F800 && r <= 0x2FA1F:
				runes = append(runes, ' ', r, ' ')
			default:
				runes = append(runes, r)
			}
		}

		for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
			// split on but keep punctuation
			var start int
			for start < len(w) {
				end := strings.IndexFunc(w[start:], unicode.IsPunct)
				if end < 0 {
					end = len(w) - start
				} else if end == 0 {
					end = 1
				}

				if !yield(w[start : start+end]) {
					return
				}

				start += end
			}
		}
	}
}

// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
	var ids []int32

	// TODO: use [UNK] from config
	unk := wpm.vocab.Encode("[UNK]")
	for word := range wpm.words(s) {
		var start int
		var pieces []int32
		for start < len(word) {
			end := len(word)

			var piece int32
			for start < end {
				subword := word[start:end]
				if start == 0 {
					subword = ggmlPrefix + subword
				}

				// TODO: some models might not want [ToLower]
				piece = wpm.vocab.Encode(strings.ToLower(subword))
				if piece >= 0 {
					break
				}

				end--
			}

			if piece < 0 {
				// Unknown token
				pieces = pieces[:0]
				break
			}

			pieces = append(pieces, piece)
			start = end
		}

		if len(pieces) > 0 {
			ids = append(ids, pieces...)
		} else {
			ids = append(ids, unk)
		}
	}

	if addSpecial && len(ids) > 0 {
		ids = wpm.vocab.addSpecials(ids)
	}

	logutil.Trace("encoded", "string", s, "ids", ids)
	return ids, nil
}

// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
	return wpm.vocab.Is(id, special)
}

// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
	return wpm.vocab
}

var _ TextProcessor = (*WordPiece)(nil)

func NewWordPiece(vocab *Vocabulary) WordPiece {
	return WordPiece{
		vocab: vocab,
	}
}
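The Encode loop above is a greedy longest-prefix match: each word is first tried with the "▁" prefix, shortened from the right until the vocabulary recognizes a piece, and the scan then resumes where that piece ended. A minimal, self-contained sketch of the same idea, using a plain map as an illustrative stand-in for the package's Vocabulary type:

package main

import (
	"fmt"
	"strings"
)

// greedyPieces mirrors the longest-prefix loop in Encode, with a map as a
// stand-in vocabulary (the real code uses *Vocabulary and returns -1 misses).
func greedyPieces(word string, vocab map[string]int32) []int32 {
	var pieces []int32
	start := 0
	for start < len(word) {
		end := len(word)
		piece := int32(-1)
		for start < end {
			subword := word[start:end]
			if start == 0 {
				subword = "▁" + subword // ggmlPrefix marks a word boundary
			}
			if id, ok := vocab[strings.ToLower(subword)]; ok {
				piece = id
				break
			}
			end-- // shorten from the right and retry
		}
		if piece < 0 {
			return nil // caller substitutes [UNK]
		}
		pieces = append(pieces, piece)
		start = end
	}
	return pieces
}

func main() {
	vocab := map[string]int32{"▁hello": 3, "▁world": 4, "s": 5}
	fmt.Println(greedyPieces("Hello", vocab))  // [3]
	fmt.Println(greedyPieces("worlds", vocab)) // [4 5]
}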
51 model/wordpiece_test.go Normal file
@@ -0,0 +1,51 @@
package model

import (
	"slices"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func TestWordPiece(t *testing.T) {
	wpm := NewWordPiece(
		&Vocabulary{
			Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
			AddBOS: true,
			AddEOS: true,
			BOS:    []int32{1},
			EOS:    []int32{2},
		})

	ids, err := wpm.Encode("Hello world!", true)
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
		t.Errorf("unexpected ids (-want +got):\n%s", diff)
	}

	words, err := wpm.Decode(ids)
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}
}

func TestWordPieceWords(t *testing.T) {
	var wpm WordPiece

	basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
	if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}

	chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
	if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}
}
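The CJK handling exercised by the second test relies on the padding trick in words: every CJK rune is wrapped in spaces so it survives the whitespace split as its own word. A standalone sketch of just that step, using unicode.Han as a simplification of the explicit ranges listed above:

package main

import (
	"fmt"
	"strings"
	"unicode"
)

// splitCJK mirrors the padding step in (WordPiece).words: CJK runes are
// surrounded by spaces, then the string is split on whitespace.
func splitCJK(s string) []string {
	var runes []rune
	for _, r := range s {
		if unicode.Is(unicode.Han, r) { // simplification of the explicit ranges above
			runes = append(runes, ' ', r, ' ')
		} else {
			runes = append(runes, r)
		}
	}
	return strings.Fields(string(runes))
}

func main() {
	fmt.Println(splitCJK("野口里佳 Noguchi Rika")) // [野 口 里 佳 Noguchi Rika]
}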
@@ -105,16 +105,18 @@ type ChatCompletionRequest struct {
	Tools           []api.Tool `json:"tools"`
	Reasoning       *Reasoning `json:"reasoning,omitempty"`
	ReasoningEffort *string    `json:"reasoning_effort,omitempty"`
+	DebugRenderOnly bool       `json:"_debug_render_only"`
}

type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
+	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
}

type ChatCompletionChunk struct {
@@ -141,6 +143,7 @@ type CompletionRequest struct {
	Temperature *float32 `json:"temperature"`
	TopP        float32  `json:"top_p"`
	Suffix      string   `json:"suffix"`
+	DebugRenderOnly bool `json:"_debug_render_only"`
}

type Completion struct {
@@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
			}
			return nil
		}(r.DoneReason),
-		}},
-		Usage: toUsage(r),
+		}}, Usage: toUsage(r),
+		DebugInfo: r.DebugInfo,
	}
}
@@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
		Think:    think,
+		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}
@@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
+		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}
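The new DebugRenderOnly and DebugInfo fields carry the debug-render mode through the OpenAI-compatible types. As a rough illustration only, with the field name taken from the JSON tag above and the host and path assumed to be a local Ollama OpenAI-compatible endpoint, a client could opt in like this:

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Illustrative only: "_debug_render_only" comes from the JSON tag above;
	// the host/path are assumptions about a locally running server.
	body := []byte(`{
	  "model": "llama3",
	  "messages": [{"role": "user", "content": "hi"}],
	  "_debug_render_only": true
	}`)
	resp, err := http.Post("http://localhost:11434/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	// When debug mode is honored, the response carries "_debug_info".
	fmt.Println(resp.Status)
}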
@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
		req.System = c.Args
	case "license":
		licenses = append(licenses, c.Args)
+	case "renderer":
+		req.Renderer = c.Args
+	case "parser":
+		req.Parser = c.Args
	case "message":
		role, msg, _ := strings.Cut(c.Args, ": ")
		messages = append(messages, api.Message{Role: role, Content: msg})
@@ -320,7 +324,7 @@ func (c Command) String() string {
	switch c.Name {
	case "model":
		fmt.Fprintf(&sb, "FROM %s", c.Args)
-	case "license", "template", "system", "adapter":
+	case "license", "template", "system", "adapter", "renderer", "parser":
		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
	case "message":
		role, message, _ := strings.Cut(c.Args, ": ")
@@ -346,7 +350,7 @@ const (
var (
	errMissingFrom        = errors.New("no FROM line")
	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
-	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
)

type ParserError struct {
@@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {

func isValidCommand(cmd string) bool {
	switch strings.ToLower(cmd) {
-	case "from", "license", "template", "system", "adapter", "parameter", "message":
+	case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
		return true
	default:
		return false
@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
	}
}

+func TestParseFileRenderer(t *testing.T) {
+	input := `
+FROM foo
+RENDERER renderer1
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
+}
+
+func TestParseFileParser(t *testing.T) {
+	input := `
+FROM foo
+PARSER parser1
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
+}
+
func TestParseFileMessages(t *testing.T) {
	cases := []struct {
		input string
@@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
	targetFree = max(targetFree, 1)

	currentFree := c.numCtx - inputLen
-	discard := targetFree - currentFree
-
-	if discard < 0 {
-		discard = 0
-	}
-
-	return discard
+	return max(targetFree-currentFree, 0)
}

type ErrReprocessInputs struct {
@@ -242,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
	targetFree = max(targetFree, 1)

	currentFree := c.numCtx - inputLen
-	discard := targetFree - currentFree
-
-	if discard < 0 {
-		discard = 0
-	}
-
-	return discard
+	return max(targetFree-currentFree, 0)
}

type ErrReprocessInputs struct {
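Both ShiftDiscard hunks make the same simplification: clamping a negative discard to zero is equivalent to a single max call. A tiny check of that equivalence on plain integers, independent of the cache types:

package main

import "fmt"

// oldDiscard is the pre-change shape; newDiscard is the one-liner it became.
func oldDiscard(targetFree, currentFree int) int {
	discard := targetFree - currentFree
	if discard < 0 {
		discard = 0
	}
	return discard
}

func newDiscard(targetFree, currentFree int) int {
	return max(targetFree-currentFree, 0)
}

func main() {
	for _, c := range [][2]int{{5, 2}, {2, 5}, {3, 3}} {
		fmt.Println(oldDiscard(c[0], c[1]) == newDiscard(c[0], c[1])) // always true
	}
}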
@@ -11,7 +11,6 @@ import (
	"image"
	"log"
	"log/slog"
-	"math"
	"net"
	"net/http"
	"os"
@@ -32,6 +31,7 @@ import (
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/runner/common"
@@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

-	supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
+	supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone

	var activeBatch batchState
	for {
@@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er

	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
	var batchInputs []*input.Input
+	var batchOutputs []int32
	var batch input.Batch

	resumeSeq := -1
@@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
	batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
	batch.Sequences = append(batch.Sequences, seq.cache.Id)

-	seq.iBatch = len(batch.Outputs)
-	if i+1 == len(seq.inputs) {
-		batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+	seq.iBatch = len(batchOutputs)
+	if i+1 == len(seq.inputs) || seq.embeddingOnly {
+		batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
	}
	logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
	seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er

	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+	batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
	if err != nil {
		err = fmt.Errorf("failed to build graph: %w", err)
@@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
	}

	// sample a token
-	vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
-	logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+	vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+	logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
	token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
	if err != nil {
		s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@@ -898,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
+	if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
		return
	}
@@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
		batch.Positions[i] = int32(i)
	}

-	batch.Outputs = make([]int32, s.parallel)
-	for i := range batch.Outputs {
-		batch.Outputs[i] = int32(i)
-	}
-
	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

	cache := s.model.Config().Cache
	if cache != nil {
@@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
	--build-arg=OLLAMA_FAST_BUILD \
	--build-arg=CUSTOM_CPU_FLAGS \
	--build-arg=GPU_RUNNER_CPU_FLAGS \
+	--build-arg=PARALLEL \
	--build-arg=AMDGPU_TARGETS"

echo "Building Ollama"
150 server/create.go
@@ -10,8 +10,11 @@ import (
	"io"
	"io/fs"
	"log/slog"
+	"net"
	"net/http"
+	"net/url"
	"os"
+	"path"
	"path/filepath"
	"slices"
	"strings"
@@ -39,6 +42,14 @@ var (
)

func (s *Server) CreateHandler(c *gin.Context) {
+	config := &ConfigV2{
+		OS:           "linux",
+		Architecture: "amd64",
+		RootFS: RootFS{
+			Type: "layers",
+		},
+	}
+
	var r api.CreateRequest
	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
		return
	}

+	config.Renderer = r.Renderer
+	config.Parser = r.Parser
+
	for v := range r.Files {
		if !fs.ValidPath(v) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
@@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) {
		oldManifest, _ := ParseNamedManifest(name)

		var baseLayers []*layerGGML
+		var err error
+		var remote bool
+
		if r.From != "" {
-			slog.Debug("create model from model name")
+			slog.Debug("create model from model name", "from", r.From)
			fromName := model.ParseName(r.From)
			if !fromName.IsValid() {
				ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
				return
			}

-			ctx, cancel := context.WithCancel(c.Request.Context())
-			defer cancel()
-
-			baseLayers, err = parseFromModel(ctx, fromName, fn)
-			if err != nil {
-				ch <- gin.H{"error": err.Error()}
+			if r.RemoteHost != "" {
+				ru, err := remoteURL(r.RemoteHost)
+				if err != nil {
+					ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
+					return
+				}
+
+				config.RemoteModel = r.From
+				config.RemoteHost = ru
+				remote = true
+			} else {
+				ctx, cancel := context.WithCancel(c.Request.Context())
+				defer cancel()
+
+				baseLayers, err = parseFromModel(ctx, fromName, fn)
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+				}
			}
		} else if r.Files != nil {
			baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
@@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
		}

		var adapterLayers []*layerGGML
-		if r.Adapters != nil {
+		if !remote && r.Adapters != nil {
			adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
			if err != nil {
				for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
@@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) {
			baseLayers = append(baseLayers, adapterLayers...)
		}

-		if err := createModel(r, name, baseLayers, fn); err != nil {
+		// Info is not currently exposed by Modelfiles, but allows overriding various
+		// config values
+		if r.Info != nil {
+			caps, ok := r.Info["capabilities"]
+			if ok {
+				switch tcaps := caps.(type) {
+				case []any:
+					caps := make([]string, len(tcaps))
+					for i, c := range tcaps {
+						str, ok := c.(string)
+						if !ok {
+							continue
+						}
+						caps[i] = str
+					}
+					config.Capabilities = append(config.Capabilities, caps...)
+				}
+			}
+
+			strFromInfo := func(k string) string {
+				v, ok := r.Info[k]
+				if ok {
+					val := v.(string)
+					return val
+				}
+				return ""
+			}
+
+			vFromInfo := func(k string) float64 {
+				v, ok := r.Info[k]
+				if ok {
+					val := v.(float64)
+					return val
+				}
+				return 0
+			}
+
+			config.ModelFamily = strFromInfo("model_family")
+			if config.ModelFamily != "" {
+				config.ModelFamilies = []string{config.ModelFamily}
+			}
+
+			config.BaseName = strFromInfo("base_name")
+			config.FileType = strFromInfo("quantization_level")
+			config.ModelType = strFromInfo("parameter_size")
+			config.ContextLen = int(vFromInfo("context_length"))
+			config.EmbedLen = int(vFromInfo("embedding_length"))
+		}
+
+		if err := createModel(r, name, baseLayers, config, fn); err != nil {
			if errors.Is(err, errBadTemplate) {
				ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
				return
@@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) {
	streamResponse(c, ch)
}

+func remoteURL(raw string) (string, error) {
+	// Special‑case: user supplied only a path ("/foo/bar").
+	if strings.HasPrefix(raw, "/") {
+		return (&url.URL{
+			Scheme: "http",
+			Host:   net.JoinHostPort("localhost", "11434"),
+			Path:   path.Clean(raw),
+		}).String(), nil
+	}
+
+	if !strings.Contains(raw, "://") {
+		raw = "http://" + raw
+	}
+
+	if raw == "ollama.com" || raw == "http://ollama.com" {
+		raw = "https://ollama.com:443"
+	}
+
+	u, err := url.Parse(raw)
+	if err != nil {
+		return "", fmt.Errorf("parse error: %w", err)
+	}
+
+	if u.Host == "" {
+		u.Host = "localhost"
+	}
+
+	hostPart, portPart, err := net.SplitHostPort(u.Host)
+	if err == nil {
+		u.Host = net.JoinHostPort(hostPart, portPart)
+	} else {
+		u.Host = net.JoinHostPort(u.Host, "11434")
+	}
+
+	if u.Path != "" {
+		u.Path = path.Clean(u.Path)
+	}
+
+	if u.Path == "/" {
+		u.Path = ""
+	}
+
+	return u.String(), nil
+}
+
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
	switch detectModelTypeFromFiles(files) {
	case "safetensors":
@@ -316,15 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
	return ggml.KV{}, fmt.Errorf("no base model was found")
}

-func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
-	config := ConfigV2{
-		OS:           "linux",
-		Architecture: "amd64",
-		RootFS: RootFS{
-			Type: "layers",
-		},
-	}
-
+func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
	var layers []Layer
	for _, layer := range baseLayers {
		if layer.GGML != nil {
@@ -404,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
		return err
	}

-	configLayer, err := createConfigLayer(layers, config)
+	configLayer, err := createConfigLayer(layers, *config)
	if err != nil {
		return err
	}
@@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) {
		})
	}
}
+
+func TestRemoteURL(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected string
+		hasError bool
+	}{
+		{
+			name:     "absolute path",
+			input:    "/foo/bar",
+			expected: "http://localhost:11434/foo/bar",
+			hasError: false,
+		},
+		{
+			name:     "absolute path with cleanup",
+			input:    "/foo/../bar",
+			expected: "http://localhost:11434/bar",
+			hasError: false,
+		},
+		{
+			name:     "root path",
+			input:    "/",
+			expected: "http://localhost:11434/",
+			hasError: false,
+		},
+		{
+			name:     "host without scheme",
+			input:    "example.com",
+			expected: "http://example.com:11434",
+			hasError: false,
+		},
+		{
+			name:     "host with port",
+			input:    "example.com:8080",
+			expected: "http://example.com:8080",
+			hasError: false,
+		},
+		{
+			name:     "full URL",
+			input:    "https://example.com:8080/path",
+			expected: "https://example.com:8080/path",
+			hasError: false,
+		},
+		{
+			name:     "full URL with path cleanup",
+			input:    "https://example.com:8080/path/../other",
+			expected: "https://example.com:8080/other",
+			hasError: false,
+		},
+		{
+			name:     "ollama.com special case",
+			input:    "ollama.com",
+			expected: "https://ollama.com:443",
+			hasError: false,
+		},
+		{
+			name:     "http ollama.com special case",
+			input:    "http://ollama.com",
+			expected: "https://ollama.com:443",
+			hasError: false,
+		},
+		{
+			name:     "URL with only host",
+			input:    "http://example.com",
+			expected: "http://example.com:11434",
+			hasError: false,
+		},
+		{
+			name:     "URL with root path cleaned",
+			input:    "http://example.com/",
+			expected: "http://example.com:11434",
+			hasError: false,
+		},
+		{
+			name:     "invalid URL",
+			input:    "http://[::1]:namedport", // invalid port
+			expected: "",
+			hasError: true,
+		},
+		{
+			name:     "empty string",
+			input:    "",
+			expected: "http://localhost:11434",
+			hasError: false,
+		},
+		{
+			name:     "host with scheme but no port",
+			input:    "http://localhost",
+			expected: "http://localhost:11434",
+			hasError: false,
+		},
+		{
+			name:     "complex path cleanup",
+			input:    "/a/b/../../c/./d",
+			expected: "http://localhost:11434/c/d",
+			hasError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := remoteURL(tt.input)
+
+			if tt.hasError {
+				if err == nil {
+					t.Errorf("expected error but got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Errorf("unexpected error: %v", err)
+				return
+			}
+
+			if result != tt.expected {
+				t.Errorf("expected %q, got %q", tt.expected, result)
+			}
+		})
+	}
+}
+
+func TestRemoteURL_Idempotent(t *testing.T) {
+	// Test that applying remoteURL twice gives the same result as applying it once
+	testInputs := []string{
+		"/foo/bar",
+		"example.com",
+		"https://example.com:8080/path",
+		"ollama.com",
+		"http://localhost:11434",
+	}
+
+	for _, input := range testInputs {
+		t.Run(input, func(t *testing.T) {
+			firstResult, err := remoteURL(input)
+			if err != nil {
+				t.Fatalf("first call failed: %v", err)
+			}
+
+			secondResult, err := remoteURL(firstResult)
+			if err != nil {
+				t.Fatalf("second call failed: %v", err)
+			}
+
+			if firstResult != secondResult {
+				t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
+			}
+		})
+	}
+}
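One rule the tests above exercise repeatedly is port defaulting: a host without an explicit port gets Ollama's default 11434 appended. A standalone sketch of just that step, mirroring the SplitHostPort/JoinHostPort handling inside remoteURL (illustration only, not the full normalization):

package main

import (
	"fmt"
	"net"
	"net/url"
)

// defaultPort appends :11434 when the URL's host carries no port, as remoteURL does.
func defaultPort(raw string) string {
	u, err := url.Parse(raw)
	if err != nil {
		return raw // error handling elided in this sketch
	}
	if _, _, err := net.SplitHostPort(u.Host); err != nil {
		u.Host = net.JoinHostPort(u.Host, "11434")
	}
	return u.String()
}

func main() {
	fmt.Println(defaultPort("http://example.com"))      // http://example.com:11434
	fmt.Println(defaultPort("http://example.com:8080")) // http://example.com:8080
}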
@@ -24,6 +24,7 @@ import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/gguf"
+	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/thinking"
@@ -73,29 +74,38 @@ func (m *Model) Capabilities() []model.Capability {
	capabilities := []model.Capability{}

	// Check for completion capability
-	f, err := gguf.Open(m.ModelPath)
-	if err == nil {
-		defer f.Close()
-
-		if f.KeyValue("pooling_type").Valid() {
-			capabilities = append(capabilities, model.CapabilityEmbedding)
-		} else {
-			// If no embedding is specified, we assume the model supports completion
-			capabilities = append(capabilities, model.CapabilityCompletion)
-		}
-		if f.KeyValue("vision.block_count").Valid() {
-			capabilities = append(capabilities, model.CapabilityVision)
-		}
-	} else {
-		slog.Error("couldn't open model file", "error", err)
-	}
+	if m.ModelPath != "" {
+		f, err := gguf.Open(m.ModelPath)
+		if err == nil {
+			defer f.Close()
+
+			if f.KeyValue("pooling_type").Valid() {
+				capabilities = append(capabilities, model.CapabilityEmbedding)
+			} else {
+				// If no embedding is specified, we assume the model supports completion
+				capabilities = append(capabilities, model.CapabilityCompletion)
+			}
+			if f.KeyValue("vision.block_count").Valid() {
+				capabilities = append(capabilities, model.CapabilityVision)
+			}
+		} else {
+			slog.Error("couldn't open model file", "error", err)
+		}
+	} else if len(m.Config.Capabilities) > 0 {
+		for _, c := range m.Config.Capabilities {
+			capabilities = append(capabilities, model.Capability(c))
+		}
+	} else {
+		slog.Warn("unknown capabilities for model", "model", m.Name)
+	}

	if m.Template == nil {
		return capabilities
	}

+	builtinParser := parsers.ParserForName(m.Config.Parser)
	// Check for tools capability
-	if slices.Contains(m.Template.Vars(), "tools") {
+	if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
		capabilities = append(capabilities, model.CapabilityTools)
	}
@@ -109,10 +119,16 @@ func (m *Model) Capabilities() []model.Capability {
		capabilities = append(capabilities, model.CapabilityVision)
	}

+	// Skip the thinking check if it's already set
+	if slices.Contains(capabilities, "thinking") {
+		return capabilities
+	}
+
	// Check for thinking capability
	openingTag, closingTag := thinking.InferTags(m.Template.Template)
	hasTags := openingTag != "" && closingTag != ""
-	if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
+	isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
+	if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
		capabilities = append(capabilities, model.CapabilityThinking)
	}
@@ -198,6 +214,20 @@ func (m *Model) String() string {
		})
	}

+	if m.Config.Renderer != "" {
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
+			Name: "renderer",
+			Args: m.Config.Renderer,
+		})
+	}
+
+	if m.Config.Parser != "" {
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
+			Name: "parser",
+			Args: m.Config.Parser,
+		})
+	}
+
	for k, v := range m.Options {
		switch v := v.(type) {
		case []any:
@@ -236,8 +266,19 @@ type ConfigV2 struct {
	ModelFormat   string   `json:"model_format"`
	ModelFamily   string   `json:"model_family"`
	ModelFamilies []string `json:"model_families"`
-	ModelType     string   `json:"model_type"`
-	FileType      string   `json:"file_type"`
+	ModelType     string   `json:"model_type"` // shown as Parameter Size
+	FileType      string   `json:"file_type"`  // shown as Quantization Level
+
+	Renderer string `json:"renderer,omitempty"`
+	Parser   string `json:"parser,omitempty"`
+
+	RemoteHost  string `json:"remote_host,omitempty"`
+	RemoteModel string `json:"remote_model,omitempty"`
+
+	// used for remotes
+	Capabilities []string `json:"capabilities,omitempty"`
+	ContextLen   int      `json:"context_length,omitempty"`
+	EmbedLen     int      `json:"embedding_length,omitempty"`
+	BaseName     string   `json:"base_name,omitempty"`

	// required by spec
	Architecture string `json:"architecture"`
@@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {

		// n^2 backoff timer is a little smoother than the
		// common choice of 2^n.
-		d := time.Duration(n*n) * 10 * time.Millisecond
-		if d > maxBackoff {
-			d = maxBackoff
-		}
+		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		// Randomize the delay between 0.5-1.5 x msec, in order
		// to prevent accidental "thundering herd" problems.
		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
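The backoff change folds the manual cap into min: the delay is still n*n * 10ms, capped at maxBackoff, then jittered by a factor in [0.5, 1.5). A standalone sketch of the resulting schedule, assuming nothing beyond the formula visible in the hunk:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

func main() {
	maxBackoff := 500 * time.Millisecond
	for n := 0; n < 10; n++ {
		// quadratic growth, capped at maxBackoff
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		// jitter: scale by a random factor in [0.5, 1.5)
		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
		fmt.Printf("attempt %d: sleep ~%v\n", n, d)
	}
}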
@@ -11,6 +11,7 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/model/renderers"
	"github.com/ollama/ollama/template"
)

@@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
		}
	}

-	thinkVal := false
-	thinkLevel := ""
-	if think != nil {
-		thinkVal = think.Bool()
-		thinkLevel = think.String()
-	}
-	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+	p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
+	if err != nil {
		return "", nil, err
	}

-	s, err := tokenize(ctx, b.String())
+	s, err := tokenize(ctx, p)
	if err != nil {
		return "", nil, err
	}
@@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
	}

	// truncate any messages that do not fit into the context window
+	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
+	if err != nil {
+		return "", nil, err
+	}
+
+	return p, images, nil
+}
+
+func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+	if m.Config.Renderer != "" {
+		rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
+		if err != nil {
+			return "", err
+		}
+		return rendered, nil
+	}
+
	var b bytes.Buffer
	thinkVal := false
	thinkLevel := ""
@@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
		thinkVal = think.Bool()
		thinkLevel = think.String()
	}
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
-		return "", nil, err
+	if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+		return "", err
	}
-
-	return b.String(), images, nil
+	return b.String(), nil
}
354 server/routes.go
@@ -15,6 +15,7 @@ import (
	"net"
	"net/http"
	"net/netip"
+	"net/url"
	"os"
	"os/signal"
	"slices"
@@ -28,6 +29,7 @@ import (
	"golang.org/x/sync/errgroup"

	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
@@ -35,6 +37,7 @@ import (
	"github.com/ollama/ollama/harmony"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/openai"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/registry"
@@ -188,6 +191,87 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

+	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+		origModel := req.Model
+
+		remoteURL, err := url.Parse(m.Config.RemoteHost)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
+			slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
+			c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
+			return
+		}
+
+		req.Model = m.Config.RemoteModel
+
+		if req.Template == "" && m.Template.String() != "" {
+			req.Template = m.Template.String()
+		}
+
+		if req.Options == nil {
+			req.Options = map[string]any{}
+		}
+
+		for k, v := range m.Options {
+			if _, ok := req.Options[k]; !ok {
+				req.Options[k] = v
+			}
+		}
+
+		// update the system prompt from the model if one isn't already specified
+		if req.System == "" && m.System != "" {
+			req.System = m.System
+		}
+
+		if len(m.Messages) > 0 {
+			slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
+		}
+
+		fn := func(resp api.GenerateResponse) error {
+			resp.Model = origModel
+			resp.RemoteModel = m.Config.RemoteModel
+			resp.RemoteHost = m.Config.RemoteHost
+
+			data, err := json.Marshal(resp)
+			if err != nil {
+				return err
+			}
+
+			if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+				return err
+			}
+			c.Writer.Flush()
+			return nil
+		}
+
+		client := api.NewClient(remoteURL, http.DefaultClient)
+		err = client.Generate(c, &req, fn)
+		if err != nil {
+			var sErr api.AuthorizationError
+			if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+				pk, pkErr := auth.GetPublicKey()
+				if pkErr != nil {
+					slog.Error("couldn't get public key", "error", pkErr)
+					c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
+					return
+				}
+				c.JSON(http.StatusUnauthorized, gin.H{
+					"error":      "unauthorized",
+					"public_key": pk,
+				})
+				return
+			}
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		return
+	}
+
	// expire the runner
	if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
		s.sched.expireRunner(m)
@@ -329,10 +413,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
-		c.JSON(http.StatusOK, api.DebugTemplateResponse{
+		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
-			DebugInfo: api.DebugInfo{
+			DebugInfo: &api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},
@@ -348,6 +432,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
			OpeningTag: openingTag,
			ClosingTag: closingTag,
		}
+		if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
+			thinkingState.AddContent(openingTag)
+		}
	}
}
@@ -488,7 +575,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
	}

	truncate := true
-
	if req.Truncate != nil && !*req.Truncate {
		truncate = false
	}
@@ -551,11 +637,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
	ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
	if len(tokens) > ctxLen {
		if !truncate {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+			c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
+			return
+		}
+
+		if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+			ctxLen--
+		}
+
+		if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+			ctxLen--
+		}
+
+		slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
+		if ctxLen <= 0 {
+			// return error if the truncated input would be empty or just special tokens
+			c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
			return
		}

		tokens = tokens[:ctxLen]
		s, err = r.Detokenize(c.Request.Context(), tokens)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -922,6 +1024,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
		ModifiedAt: manifest.fi.ModTime(),
	}

+	if m.Config.RemoteHost != "" {
+		resp.RemoteHost = m.Config.RemoteHost
+		resp.RemoteModel = m.Config.RemoteModel
+
+		if m.Config.ModelFamily != "" {
+			resp.ModelInfo = make(map[string]any)
+			resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
+
+			if m.Config.BaseName != "" {
+				resp.ModelInfo["general.basename"] = m.Config.BaseName
+			}
+
+			if m.Config.ContextLen > 0 {
+				resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
+			}
+
+			if m.Config.EmbedLen > 0 {
+				resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
+			}
+		}
+	}
+
	var params []string
	cs := 30
	for k, v := range m.Options {
@@ -952,6 +1076,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	fmt.Fprint(&sb, m.String())
	resp.Modelfile = sb.String()

+	// skip loading tensor information if this is a remote model
+	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+		return resp, nil
+	}
+
	kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
	if err != nil {
		return nil, err
@@ -1028,11 +1157,13 @@ func (s *Server) ListHandler(c *gin.Context) {

		// tag should never be masked
		models = append(models, api.ListModelResponse{
			Model:       n.DisplayShortest(),
			Name:        n.DisplayShortest(),
+			RemoteModel: cf.RemoteModel,
+			RemoteHost:  cf.RemoteHost,
			Size:        m.Size(),
			Digest:      m.digest,
			ModifiedAt:  m.fi.ModTime(),
			Details: api.ModelDetails{
				Format: cf.ModelFormat,
				Family: cf.ModelFamily,
@@ -1292,6 +1423,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	r.POST("/api/show", s.ShowHandler)
	r.DELETE("/api/delete", s.DeleteHandler)

+	r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
+	r.POST("/api/me", s.WhoamiHandler)
+
	// Create
	r.POST("/api/create", s.CreateHandler)
	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
@@ -1488,6 +1622,49 @@ func streamResponse(c *gin.Context, ch chan any) {
	})
}

+func (s *Server) WhoamiHandler(c *gin.Context) {
+	// todo allow other hosts
+	u, err := url.Parse("https://ollama.com")
+	if err != nil {
+		slog.Error(err.Error())
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
+		return
+	}
+
+	client := api.NewClient(u, http.DefaultClient)
+	user, err := client.Whoami(c)
+	if err != nil {
+		slog.Error(err.Error())
+	}
+	c.JSON(http.StatusOK, user)
+}
+
+func (s *Server) SignoutHandler(c *gin.Context) {
+	encodedKey := c.Param("encodedKey")
+
+	// todo allow other hosts
+	u, err := url.Parse("https://ollama.com")
+	if err != nil {
+		slog.Error(err.Error())
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
+		return
+	}
+
+	client := api.NewClient(u, http.DefaultClient)
+	err = client.Signout(c, encodedKey)
+	if err != nil {
+		slog.Error(err.Error())
+		if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
+			c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
+		return
+	}
+
+	c.JSON(http.StatusOK, nil)
+}
+
func (s *Server) PsHandler(c *gin.Context) {
	models := []api.ProcessModelResponse{}
@@ -1544,21 +1721,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
         return
     }
 
-    // expire the runner
-    if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
-        model, err := GetModel(req.Model)
-        if err != nil {
-            switch {
-            case os.IsNotExist(err):
-                c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
-            case err.Error() == errtypes.InvalidModelNameErrMsg:
-                c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-            default:
-                c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-            }
-            return
-        }
-        s.sched.expireRunner(model)
+    name := model.ParseName(req.Model)
+    if !name.IsValid() {
+        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+        return
+    }
+
+    name, err := getExistingName(name)
+    if err != nil {
+        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+        return
+    }
+
+    m, err := GetModel(req.Model)
+    if err != nil {
+        switch {
+        case os.IsNotExist(err):
+            c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+        case err.Error() == errtypes.InvalidModelNameErrMsg:
+            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+        default:
+            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+        }
+        return
+    }
+
+    // expire the runner
+    if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+        s.sched.expireRunner(m)
 
         c.JSON(http.StatusOK, api.ChatResponse{
             Model: req.Model,
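The keep-alive early exit above still means a client can unload a model by sending a chat request with no messages and a zero keep_alive; only the zero check changed (KeepAlive.Seconds() instead of KeepAlive.Duration). A client-side sketch of that unload call (model name and environment are placeholders):

// Sketch: ask the server to expire a loaded runner by sending an empty chat
// request with keep_alive set to zero. "llama3.2" is a placeholder model name.
package main

import (
    "context"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    req := &api.ChatRequest{
        Model:     "llama3.2",                 // placeholder model name
        Messages:  nil,                        // no messages: triggers the expire path
        KeepAlive: &api.Duration{Duration: 0}, // keep_alive == 0
    }

    // The server replies with a single ChatResponse acknowledging the unload.
    err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
        log.Printf("done=%v", resp.Done)
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
}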
@@ -1570,6 +1760,80 @@ func (s *Server) ChatHandler(c *gin.Context) {
             return
         }
 
+    if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+        origModel := req.Model
+
+        remoteURL, err := url.Parse(m.Config.RemoteHost)
+        if err != nil {
+            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+            return
+        }
+
+        if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
+            slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
+            c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
+            return
+        }
+
+        req.Model = m.Config.RemoteModel
+        if req.Options == nil {
+            req.Options = map[string]any{}
+        }
+
+        msgs := append(m.Messages, req.Messages...)
+        if req.Messages[0].Role != "system" && m.System != "" {
+            msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
+        }
+        msgs = filterThinkTags(msgs, m)
+        req.Messages = msgs
+
+        for k, v := range m.Options {
+            if _, ok := req.Options[k]; !ok {
+                req.Options[k] = v
+            }
+        }
+
+        fn := func(resp api.ChatResponse) error {
+            resp.Model = origModel
+            resp.RemoteModel = m.Config.RemoteModel
+            resp.RemoteHost = m.Config.RemoteHost
+
+            data, err := json.Marshal(resp)
+            if err != nil {
+                return err
+            }
+
+            if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+                return err
+            }
+            c.Writer.Flush()
+            return nil
+        }
+
+        client := api.NewClient(remoteURL, http.DefaultClient)
+        err = client.Chat(c, &req, fn)
+        if err != nil {
+            var sErr api.AuthorizationError
+            if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+                pk, pkErr := auth.GetPublicKey()
+                if pkErr != nil {
+                    slog.Error("couldn't get public key", "error", pkErr)
+                    c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
+                    return
+                }
+                c.JSON(http.StatusUnauthorized, gin.H{
+                    "error": "unauthorized",
+                    "public_key": pk,
+                })
+                return
+            }
+            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+            return
+        }
+
+        return
+    }
+
     caps := []model.Capability{model.CapabilityCompletion}
     if len(req.Tools) > 0 {
         caps = append(caps, model.CapabilityTools)
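Before forwarding, the block above compares only the hostname of m.Config.RemoteHost against envconfig.Remotes(). A standalone sketch of that allow-list check; remoteAllowed is an illustrative helper, not a function from the codebase:

// Sketch of the hostname allow-list gate used before proxying a chat request
// to a remote host. Uses only the standard library.
package main

import (
    "fmt"
    "net/url"
    "slices"
)

func remoteAllowed(remoteHost string, remotes []string) (bool, error) {
    u, err := url.Parse(remoteHost)
    if err != nil {
        return false, err
    }
    // Only the hostname is compared, mirroring remoteURL.Hostname() above.
    return slices.Contains(remotes, u.Hostname()), nil
}

func main() {
    fmt.Println(remoteAllowed("https://ollama.com", []string{"ollama.com"}))  // true <nil>
    fmt.Println(remoteAllowed("https://example.com", []string{"ollama.com"})) // false <nil>
}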
@@ -1578,17 +1842,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
         caps = append(caps, model.CapabilityThinking)
     }
 
-    name := model.ParseName(req.Model)
-    if !name.IsValid() {
-        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-        return
-    }
-    name, err := getExistingName(name)
-    if err != nil {
-        c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-        return
-    }
-
     r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
     if errors.Is(err, errCapabilityCompletion) {
         c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1617,10 +1870,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
     }
     msgs = filterThinkTags(msgs, m)
 
+    var builtinParser parsers.Parser
+    if m.Config.Parser != "" {
+        builtinParser = parsers.ParserForName(m.Config.Parser)
+    }
+
     var harmonyMessageHandler *harmony.HarmonyMessageHandler
     var harmonyToolParser *harmony.HarmonyToolCallAccumulator
 
-    useHarmony := shouldUseHarmony(m)
+    useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
 
     processedTools := req.Tools
     if useHarmony {
@@ -1650,10 +1908,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
     // If debug mode is enabled, return the rendered template instead of calling the model
     if req.DebugRenderOnly {
-        c.JSON(http.StatusOK, api.DebugTemplateResponse{
+        c.JSON(http.StatusOK, api.ChatResponse{
             Model: req.Model,
             CreatedAt: time.Now().UTC(),
-            DebugInfo: api.DebugInfo{
+            DebugInfo: &api.DebugInfo{
                 RenderedTemplate: prompt,
                 ImageCount: len(images),
             },
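With this change, the debug-render path responds with a regular api.ChatResponse whose DebugInfo pointer is populated, rather than a dedicated DebugTemplateResponse. A sketch of how a client might request the rendered template; it assumes DebugRenderOnly is a field on api.ChatRequest, as the use of req.DebugRenderOnly above implies:

// Sketch: request the rendered prompt template instead of a generation.
// Assumes api.ChatRequest exposes DebugRenderOnly, per the hunk above.
package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    req := &api.ChatRequest{
        Model:           "llama3.2", // placeholder model name
        Messages:        []api.Message{{Role: "user", Content: "hello"}},
        DebugRenderOnly: true,       // assumed field, mirroring req.DebugRenderOnly
    }

    err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
        if resp.DebugInfo != nil {
            fmt.Println(resp.DebugInfo.RenderedTemplate)
            fmt.Println("images:", resp.DebugInfo.ImageCount)
        }
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
}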
@@ -1713,6 +1971,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
                 res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
             }
 
+            // TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic
             if useHarmony {
                 content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
                 res.Message.Content = content
@@ -1739,6 +1998,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
                     ch <- res
                 }
 
+                return
+            } else if builtinParser != nil {
+                slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
+
+                content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
+                if err != nil {
+                    ch <- gin.H{"error": err.Error()}
+                    return
+                }
+
+                res.Message.Content = content
+                res.Message.Thinking = thinking
+                res.Message.ToolCalls = toolCalls
+
+                if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
+                    slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
+                    ch <- res
+                } else {
+                    slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
+                }
+
                 return
             }
 
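The builtinParser branch feeds every streamed chunk through Add and only forwards a response when it yields content, thinking text, tool calls, or the final chunk. A minimal sketch of that accumulation contract; the Parser interface below is paraphrased from the call sites above, not the real parsers package definition:

// Illustrative sketch of the Add-based streaming contract used by the
// builtinParser branch. The interface and passthrough parser are assumptions.
package main

import "fmt"

type ToolCall struct{ Name string }

// Parser mirrors the shape implied by builtinParser.Add(content, tools).
type Parser interface {
    Add(chunk string, tools []string) (content, thinking string, calls []ToolCall, err error)
}

// passthrough emits every chunk as plain content.
type passthrough struct{}

func (passthrough) Add(chunk string, _ []string) (string, string, []ToolCall, error) {
    return chunk, "", nil, nil
}

func main() {
    var p Parser = passthrough{}
    for _, chunk := range []string{"Hel", "lo", ""} {
        content, thinking, calls, err := p.Add(chunk, nil)
        if err != nil {
            panic(err)
        }
        // Mirror the server: skip chunks that produced nothing.
        if content == "" && thinking == "" && len(calls) == 0 {
            continue
        }
        fmt.Printf("content=%q thinking=%q calls=%d\n", content, thinking, len(calls))
    }
}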
@@ -11,6 +11,7 @@ import (
     "net/http/httptest"
     "os"
     "path/filepath"
+    "reflect"
     "slices"
     "strings"
     "testing"
@@ -20,6 +21,7 @@ import (
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/fs/ggml"
+    "github.com/ollama/ollama/types/model"
 )
 
 var stream bool = false
@@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) {
     })
 }
 
+func TestCreateAndShowRemoteModel(t *testing.T) {
+    gin.SetMode(gin.TestMode)
+
+    var s Server
+
+    w := createRequest(t, s.CreateHandler, api.CreateRequest{
+        Model: "test",
+        From: "bob",
+        RemoteHost: "https://ollama.com",
+        Info: map[string]any{
+            "capabilities": []string{"completion", "tools", "thinking"},
+            "model_family": "gptoss",
+            "context_length": 131072,
+            "embedding_length": 2880,
+            "quantization_level": "MXFP4",
+            "parameter_size": "20.9B",
+        },
+        Stream: &stream,
+    })
+
+    if w.Code != http.StatusOK {
+        t.Fatalf("exected status code 200, actual %d", w.Code)
+    }
+
+    w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
+    if w.Code != http.StatusOK {
+        t.Fatalf("exected status code 200, actual %d", w.Code)
+    }
+
+    var resp api.ShowResponse
+    if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+        t.Fatal(err)
+    }
+
+    expectedDetails := api.ModelDetails{
+        ParentModel: "",
+        Format: "",
+        Family: "gptoss",
+        Families: []string{"gptoss"},
+        ParameterSize: "20.9B",
+        QuantizationLevel: "MXFP4",
+    }
+
+    if !reflect.DeepEqual(resp.Details, expectedDetails) {
+        t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
+    }
+
+    expectedCaps := []model.Capability{
+        model.Capability("completion"),
+        model.Capability("tools"),
+        model.Capability("thinking"),
+    }
+
+    if !slices.Equal(resp.Capabilities, expectedCaps) {
+        t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
+    }
+
+    v, ok := resp.ModelInfo["gptoss.context_length"]
+    ctxlen := v.(float64)
+    if !ok || int(ctxlen) != 131072 {
+        t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen))
+    }
+
+    v, ok = resp.ModelInfo["gptoss.embedding_length"]
+    embedlen := v.(float64)
+    if !ok || int(embedlen) != 2880 {
+        t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen))
+    }
+
+    fmt.Printf("resp = %#v\n", resp)
+}
+
 func TestCreateLicenses(t *testing.T) {
     gin.SetMode(gin.TestMode)
 
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
         t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
     }
 
-    var response api.DebugTemplateResponse
+    var response api.GenerateResponse
     if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
         t.Fatalf("failed to unmarshal response: %v", err)
     }
@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
         t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
     }
 
-    var response api.DebugTemplateResponse
+    var response api.ChatResponse
     if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
         t.Fatalf("failed to unmarshal response: %v", err)
     }
@@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) {
             t.Fatalf("failed to create model: %v", err)
         }
 
-        if err := createModel(r, modelName, baseLayers, fn); err != nil {
+        config := &ConfigV2{
+            OS: "linux",
+            Architecture: "amd64",
+            RootFS: RootFS{
+                Type: "layers",
+            },
+        }
+
+        if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
             t.Fatal(err)
         }
     }
@@ -382,10 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
 // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
 // (if any). Returns whether the scheduler needs to evict a model to make this one fit.
 func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
-    numParallel := int(envconfig.NumParallel())
-    if numParallel < 1 {
-        numParallel = 1
-    }
+    numParallel := max(int(envconfig.NumParallel()), 1)
 
     // Embedding models should always be loaded with parallel=1
     if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
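The scheduler change simply replaces the clamp-to-one conditional with Go's built-in max (available since Go 1.21), which is equivalent for this case:

// Equivalence check for the numParallel clamp rewritten with the max builtin.
package main

import "fmt"

func clampOld(n int) int {
    if n < 1 {
        n = 1
    }
    return n
}

func clampNew(n int) int {
    return max(n, 1) // Go 1.21+ builtin
}

func main() {
    for _, n := range []int{-2, 0, 1, 4} {
        fmt.Println(n, clampOld(n) == clampNew(n)) // all true
    }
}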