mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
engine: add remote proxy (#12307)
This commit is contained in:
152
server/create.go
152
server/create.go
@@ -10,8 +10,11 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -39,6 +42,14 @@ var (
|
||||
)
|
||||
|
||||
func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config := &ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
var r api.CreateRequest
|
||||
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
@@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
config.Renderer = r.Renderer
|
||||
config.Parser = r.Parser
|
||||
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
@@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
oldManifest, _ := ParseNamedManifest(name)
|
||||
|
||||
var baseLayers []*layerGGML
|
||||
var err error
|
||||
var remote bool
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name")
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromName := model.ParseName(r.From)
|
||||
if !fromName.IsValid() {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
if r.RemoteHost != "" {
|
||||
ru, err := remoteURL(r.RemoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
config.RemoteModel = r.From
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
baseLayers, err = parseFromModel(ctx, fromName, fn)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
baseLayers, err = parseFromModel(ctx, fromName, fn)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}
|
||||
} else if r.Files != nil {
|
||||
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
|
||||
@@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
var adapterLayers []*layerGGML
|
||||
if r.Adapters != nil {
|
||||
if !remote && r.Adapters != nil {
|
||||
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
||||
if err != nil {
|
||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
|
||||
@@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
baseLayers = append(baseLayers, adapterLayers...)
|
||||
}
|
||||
|
||||
if err := createModel(r, name, baseLayers, fn); err != nil {
|
||||
// Info is not currently exposed by Modelfiles, but allows overriding various
|
||||
// config values
|
||||
if r.Info != nil {
|
||||
caps, ok := r.Info["capabilities"]
|
||||
if ok {
|
||||
switch tcaps := caps.(type) {
|
||||
case []any:
|
||||
caps := make([]string, len(tcaps))
|
||||
for i, c := range tcaps {
|
||||
str, ok := c.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
caps[i] = str
|
||||
}
|
||||
config.Capabilities = append(config.Capabilities, caps...)
|
||||
}
|
||||
}
|
||||
|
||||
strFromInfo := func(k string) string {
|
||||
v, ok := r.Info[k]
|
||||
if ok {
|
||||
val := v.(string)
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
vFromInfo := func(k string) float64 {
|
||||
v, ok := r.Info[k]
|
||||
if ok {
|
||||
val := v.(float64)
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
config.ModelFamily = strFromInfo("model_family")
|
||||
if config.ModelFamily != "" {
|
||||
config.ModelFamilies = []string{config.ModelFamily}
|
||||
}
|
||||
|
||||
config.BaseName = strFromInfo("base_name")
|
||||
config.FileType = strFromInfo("quantization_level")
|
||||
config.ModelType = strFromInfo("parameter_size")
|
||||
config.ContextLen = int(vFromInfo("context_length"))
|
||||
config.EmbedLen = int(vFromInfo("embedding_length"))
|
||||
}
|
||||
|
||||
if err := createModel(r, name, baseLayers, config, fn); err != nil {
|
||||
if errors.Is(err, errBadTemplate) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
@@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func remoteURL(raw string) (string, error) {
|
||||
// Special‑case: user supplied only a path ("/foo/bar").
|
||||
if strings.HasPrefix(raw, "/") {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("localhost", "11434"),
|
||||
Path: path.Clean(raw),
|
||||
}).String(), nil
|
||||
}
|
||||
|
||||
if !strings.Contains(raw, "://") {
|
||||
raw = "http://" + raw
|
||||
}
|
||||
|
||||
if raw == "ollama.com" || raw == "http://ollama.com" {
|
||||
raw = "https://ollama.com:443"
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse error: %w", err)
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
u.Host = "localhost"
|
||||
}
|
||||
|
||||
hostPart, portPart, err := net.SplitHostPort(u.Host)
|
||||
if err == nil {
|
||||
u.Host = net.JoinHostPort(hostPart, portPart)
|
||||
} else {
|
||||
u.Host = net.JoinHostPort(u.Host, "11434")
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
u.Path = path.Clean(u.Path)
|
||||
}
|
||||
|
||||
if u.Path == "/" {
|
||||
u.Path = ""
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
|
||||
switch detectModelTypeFromFiles(files) {
|
||||
case "safetensors":
|
||||
@@ -316,17 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
return ggml.KV{}, fmt.Errorf("no base model was found")
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
|
||||
config := ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
Renderer: r.Renderer,
|
||||
Parser: r.Parser,
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []Layer
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
@@ -406,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := createConfigLayer(layers, config)
|
||||
configLayer, err := createConfigLayer(layers, *config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "absolute path",
|
||||
input: "/foo/bar",
|
||||
expected: "http://localhost:11434/foo/bar",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path with cleanup",
|
||||
input: "/foo/../bar",
|
||||
expected: "http://localhost:11434/bar",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "root path",
|
||||
input: "/",
|
||||
expected: "http://localhost:11434/",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host without scheme",
|
||||
input: "example.com",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host with port",
|
||||
input: "example.com:8080",
|
||||
expected: "http://example.com:8080",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "full URL",
|
||||
input: "https://example.com:8080/path",
|
||||
expected: "https://example.com:8080/path",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "full URL with path cleanup",
|
||||
input: "https://example.com:8080/path/../other",
|
||||
expected: "https://example.com:8080/other",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "ollama.com special case",
|
||||
input: "ollama.com",
|
||||
expected: "https://ollama.com:443",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "http ollama.com special case",
|
||||
input: "http://ollama.com",
|
||||
expected: "https://ollama.com:443",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with only host",
|
||||
input: "http://example.com",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with root path cleaned",
|
||||
input: "http://example.com/",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid URL",
|
||||
input: "http://[::1]:namedport", // invalid port
|
||||
expected: "",
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "http://localhost:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host with scheme but no port",
|
||||
input: "http://localhost",
|
||||
expected: "http://localhost:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "complex path cleanup",
|
||||
input: "/a/b/../../c/./d",
|
||||
expected: "http://localhost:11434/c/d",
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := remoteURL(tt.input)
|
||||
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteURL_Idempotent(t *testing.T) {
|
||||
// Test that applying remoteURL twice gives the same result as applying it once
|
||||
testInputs := []string{
|
||||
"/foo/bar",
|
||||
"example.com",
|
||||
"https://example.com:8080/path",
|
||||
"ollama.com",
|
||||
"http://localhost:11434",
|
||||
}
|
||||
|
||||
for _, input := range testInputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
firstResult, err := remoteURL(input)
|
||||
if err != nil {
|
||||
t.Fatalf("first call failed: %v", err)
|
||||
}
|
||||
|
||||
secondResult, err := remoteURL(firstResult)
|
||||
if err != nil {
|
||||
t.Fatalf("second call failed: %v", err)
|
||||
}
|
||||
|
||||
if firstResult != secondResult {
|
||||
t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,21 +74,29 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
} else if len(m.Config.Capabilities) > 0 {
|
||||
for _, c := range m.Config.Capabilities {
|
||||
capabilities = append(capabilities, model.Capability(c))
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
slog.Warn("unknown capabilities for model", "model", m.Name)
|
||||
}
|
||||
|
||||
if m.Template == nil {
|
||||
@@ -111,6 +119,11 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Skip the thinking check if it's already set
|
||||
if slices.Contains(capabilities, "thinking") {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
hasTags := openingTag != "" && closingTag != ""
|
||||
@@ -253,11 +266,20 @@ type ConfigV2 struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
ModelFamily string `json:"model_family"`
|
||||
ModelFamilies []string `json:"model_families"`
|
||||
ModelType string `json:"model_type"`
|
||||
FileType string `json:"file_type"`
|
||||
ModelType string `json:"model_type"` // shown as Parameter Size
|
||||
FileType string `json:"file_type"` // shown as Quantization Level
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// used for remotes
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
ContextLen int `json:"context_length,omitempty"`
|
||||
EmbedLen int `json:"embedding_length,omitempty"`
|
||||
BaseName string `json:"base_name,omitempty"`
|
||||
|
||||
// required by spec
|
||||
Architecture string `json:"architecture"`
|
||||
OS string `json:"os"`
|
||||
|
||||
277
server/routes.go
277
server/routes.go
@@ -15,6 +15,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
@@ -28,6 +29,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -189,6 +191,84 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
origModel := req.Model
|
||||
|
||||
remoteURL, err := url.Parse(m.Config.RemoteHost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
|
||||
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = m.Config.RemoteModel
|
||||
|
||||
if req.Template == "" && m.Template.String() != "" {
|
||||
req.Template = m.Template.String()
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
|
||||
for k, v := range m.Options {
|
||||
if _, ok := req.Options[k]; !ok {
|
||||
req.Options[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// update the system prompt from the model if one isn't already specified
|
||||
if req.System == "" && m.System != "" {
|
||||
req.System = m.System
|
||||
}
|
||||
|
||||
if len(m.Messages) > 0 {
|
||||
slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
|
||||
}
|
||||
|
||||
fn := func(resp api.GenerateResponse) error {
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||
err = client.Generate(c, &req, fn)
|
||||
if err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
pk, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
slog.Error("couldn't get public key", "error", pkErr)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"public_key": pk})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
@@ -931,6 +1011,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
|
||||
if m.Config.ModelFamily != "" {
|
||||
resp.ModelInfo = make(map[string]any)
|
||||
resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
|
||||
|
||||
if m.Config.BaseName != "" {
|
||||
resp.ModelInfo["general.basename"] = m.Config.BaseName
|
||||
}
|
||||
|
||||
if m.Config.ContextLen > 0 {
|
||||
resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
|
||||
}
|
||||
|
||||
if m.Config.EmbedLen > 0 {
|
||||
resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var params []string
|
||||
cs := 30
|
||||
for k, v := range m.Options {
|
||||
@@ -961,6 +1063,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
fmt.Fprint(&sb, m.String())
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
// skip loading tensor information if this is a remote model
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1037,11 +1144,13 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
|
||||
// tag should never be masked
|
||||
models = append(models, api.ListModelResponse{
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
RemoteModel: cf.RemoteModel,
|
||||
RemoteHost: cf.RemoteHost,
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Details: api.ModelDetails{
|
||||
Format: cf.ModelFormat,
|
||||
Family: cf.ModelFamily,
|
||||
@@ -1301,6 +1410,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/show", s.ShowHandler)
|
||||
r.DELETE("/api/delete", s.DeleteHandler)
|
||||
|
||||
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
|
||||
r.POST("/api/me", s.WhoamiHandler)
|
||||
|
||||
// Create
|
||||
r.POST("/api/create", s.CreateHandler)
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
@@ -1497,6 +1609,49 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
client := api.NewClient(u, http.DefaultClient)
|
||||
user, err := client.Whoami(c)
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
}
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
encodedKey := c.Param("encodedKey")
|
||||
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
client := api.NewClient(u, http.DefaultClient)
|
||||
err = client.Signout(c, encodedKey)
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (s *Server) PsHandler(c *gin.Context) {
|
||||
models := []api.ProcessModelResponse{}
|
||||
|
||||
@@ -1553,21 +1708,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
case err.Error() == errtypes.InvalidModelNameErrMsg:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
case err.Error() == errtypes.InvalidModelNameErrMsg:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
s.sched.expireRunner(model)
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
Model: req.Model,
|
||||
@@ -1579,6 +1747,66 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
origModel := req.Model
|
||||
|
||||
remoteURL, err := url.Parse(m.Config.RemoteHost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
|
||||
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = m.Config.RemoteModel
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
|
||||
msgs := append(m.Messages, req.Messages...)
|
||||
if req.Messages[0].Role != "system" && m.System != "" {
|
||||
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
req.Messages = msgs
|
||||
|
||||
for k, v := range m.Options {
|
||||
if _, ok := req.Options[k]; !ok {
|
||||
req.Options[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
fn := func(resp api.ChatResponse) error {
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||
err = client.Chat(c, &req, fn)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
caps := []model.Capability{model.CapabilityCompletion}
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
@@ -1587,17 +1815,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var stream bool = false
|
||||
@@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateAndShowRemoteModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
From: "bob",
|
||||
RemoteHost: "https://ollama.com",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion", "tools", "thinking"},
|
||||
"model_family": "gptoss",
|
||||
"context_length": 131072,
|
||||
"embedding_length": 2880,
|
||||
"quantization_level": "MXFP4",
|
||||
"parameter_size": "20.9B",
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("exected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("exected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedDetails := api.ModelDetails{
|
||||
ParentModel: "",
|
||||
Format: "",
|
||||
Family: "gptoss",
|
||||
Families: []string{"gptoss"},
|
||||
ParameterSize: "20.9B",
|
||||
QuantizationLevel: "MXFP4",
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp.Details, expectedDetails) {
|
||||
t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
|
||||
}
|
||||
|
||||
expectedCaps := []model.Capability{
|
||||
model.Capability("completion"),
|
||||
model.Capability("tools"),
|
||||
model.Capability("thinking"),
|
||||
}
|
||||
|
||||
if !slices.Equal(resp.Capabilities, expectedCaps) {
|
||||
t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
|
||||
}
|
||||
|
||||
v, ok := resp.ModelInfo["gptoss.context_length"]
|
||||
ctxlen := v.(float64)
|
||||
if !ok || int(ctxlen) != 131072 {
|
||||
t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen))
|
||||
}
|
||||
|
||||
v, ok = resp.ModelInfo["gptoss.embedding_length"]
|
||||
embedlen := v.(float64)
|
||||
if !ok || int(embedlen) != 2880 {
|
||||
t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen))
|
||||
}
|
||||
|
||||
fmt.Printf("resp = %#v\n", resp)
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) {
|
||||
t.Fatalf("failed to create model: %v", err)
|
||||
}
|
||||
|
||||
if err := createModel(r, modelName, baseLayers, fn); err != nil {
|
||||
config := &ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user