Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2024-05-31 18:43:24 +08:00
committed by GitHub
19 changed files with 909 additions and 487 deletions

View File

@@ -34,13 +34,13 @@ jobs:
git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
}
{
echo GENERATE=$(changed llm/)
echo GENERATE_CUDA=$(changed llm/)
echo GENERATE_ROCM=$(changed llm/)
echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**')
} >>$GITHUB_OUTPUT
generate:
@@ -287,6 +287,8 @@ jobs:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: '1'
OLLAMA_CPU_TARGET: 'static'
OLLAMA_SKIP_CPU_GENERATE: '1'
OLLAMA_SKIP_METAL_GENERATE: '1'
steps:
- uses: actions/checkout@v4
with:

View File

@@ -755,7 +755,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
a := runewidth.StringWidth(state.wordBuffer)
if a > 0 {
fmt.Printf("\x1b[%dD", a)
}
fmt.Printf("\x1b[K\n")
fmt.Printf("%s%c", state.wordBuffer, ch)
chWidth := runewidth.RuneWidth(ch)
@@ -1251,6 +1255,9 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
})
default:
appendEnvDocs(cmd, envs)

View File

@@ -76,6 +76,7 @@ Make sure you've set up the container runtime first as described in [docker.md](
Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
- Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
- Try rebooting

View File

@@ -51,16 +51,16 @@ func AsMap() map[string]EnvVar {
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_ORIGINS", LLMLibrary, ""},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, ""},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", "", "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, ""},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
}
}

View File

@@ -140,7 +140,6 @@ struct server_slot {
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
bool infill = false;
bool embedding = false;
bool has_next_token = true;
bool truncated = false;
@@ -187,7 +186,6 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
n_sent_token_probs = 0;
infill = false;
ga_i = 0;
n_past_se = 0;
@@ -600,16 +598,6 @@ struct llama_server_context
slot->params.n_predict = slot->n_predict;
}
// infill
if (data.count("input_prefix") != 0)
{
slot->params.input_prefix = data["input_prefix"];
}
else
{
slot->params.input_prefix = "";
}
if (data.count("input_suffix") != 0)
{
slot->params.input_suffix = data["input_suffix"];
@@ -897,15 +885,6 @@ struct llama_server_context
system_need_update = true;
}
void system_prompt_process(const json &sys_props) {
system_prompt = sys_props.value("prompt", "");
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
system_prompt_notify();
}
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
const stop_type type, server_slot &slot)
{
@@ -1263,13 +1242,12 @@ struct llama_server_context
queue_results.send(res);
}
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
void request_completion(int task_id, json data, bool embedding, int multitask_id)
{
task_server task;
task.id = task_id;
task.target_id = 0;
task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = TASK_TYPE_COMPLETION;
task.multitask_id = multitask_id;
@@ -1415,8 +1393,8 @@ struct llama_server_context
json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
// subtasks inherit everything else (embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id);
}
}
@@ -1434,26 +1412,8 @@ struct llama_server_context
break;
}
if (task.data.contains("system_prompt"))
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
system_prompt_process(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (server_slot &slot : slots)
{
slot.cache_tokens.clear();
slot.n_past = 0;
slot.n_past_se = 0;
}
}
slot->reset();
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
@@ -1679,8 +1639,7 @@ struct llama_server_context
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
// empty prompt passed -> release the slot and send empty response
// note: infill mode allows empty prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
{
slot.release();
slot.print_timings();
@@ -1697,33 +1656,7 @@ struct llama_server_context
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
if (slot.infill)
{
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
{
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}
else
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
slot.n_prompt_tokens = prompt_tokens.size();
@@ -2130,8 +2063,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n");
}
static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context& llama)
static void server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
{
gpt_params default_params;
server_params default_sparams;
@@ -2546,27 +2478,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_predict = std::stoi(argv[i]);
}
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::string systm_content;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(systm_content)
);
llama.system_prompt_process(json::parse(systm_content));
}
else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
}
@@ -2714,21 +2625,6 @@ static json format_partial_response(
return res;
}
static json format_tokenizer_response(const std::vector<llama_token> &tokens)
{
return json {
{"tokens", tokens}
};
}
static json format_detokenized_response(std::string content)
{
return json {
{"content", content}
};
}
static void log_server_request(const httplib::Request &req, const httplib::Response &res)
{
// skip GH copilot requests when using default port
@@ -2818,7 +2714,7 @@ int main(int argc, char **argv) {
// struct that contains llama context and inference
llama_server_context llama;
server_params_parse(argc, argv, sparams, params, llama);
server_params_parse(argc, argv, sparams, params);
if (params.model_alias == "unknown")
{
@@ -3150,7 +3046,7 @@ int main(int argc, char **argv) {
json data = json::parse(req.body);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
llama.request_completion(task_id, data, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
@@ -3218,34 +3114,6 @@ int main(int argc, char **argv) {
}
});
svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::vector<llama_token> tokens;
if (body.count("content") != 0)
{
tokens = llama.tokenize(body["content"], false);
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::string content;
if (body.count("tokens") != 0)
{
const std::vector<llama_token> tokens = body["tokens"];
content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
}
const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
@@ -3272,7 +3140,7 @@ int main(int argc, char **argv) {
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);

View File

@@ -32,42 +32,43 @@ case "${GOARCH}" in
echo "Building static library"
build
if [ -z "$OLLAMA_SKIP_CPU_GENERATE" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
fi
;;
"arm64")
@@ -79,13 +80,15 @@ case "${GOARCH}" in
echo "Building static library"
build
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
fi
;;
*)
echo "GOARCH must be set"

View File

@@ -12,6 +12,7 @@ package llm
import "C"
import (
"fmt"
"strings"
"unsafe"
)
@@ -37,3 +38,62 @@ func Quantize(infile, outfile string, ftype fileType) error {
return nil
}
type llamaModel struct {
m *C.struct_llama_model
}
func newLlamaModel(p string) *llamaModel {
cs := C.CString(p)
defer C.free(unsafe.Pointer(cs))
params := C.llama_model_default_params()
params.vocab_only = true
return &llamaModel{
C.llama_load_model_from_file(cs, params),
}
}
func (llm *llamaModel) Close() {
C.llama_free_model(llm.m)
}
func (llm *llamaModel) Tokenize(s string) []int {
cs := C.CString(s)
defer C.free(unsafe.Pointer(cs))
ltokens := make([]C.llama_token, len(s)+2)
n := C.llama_tokenize(
llm.m,
cs,
C.int32_t(len(s)),
&ltokens[0],
C.int32_t(len(ltokens)),
false,
true,
)
if n < 0 {
return nil
}
tokens := make([]int, n)
for i := 0; i < int(n); i++ {
tokens[i] = int(ltokens[i])
}
return tokens
}
func (llm *llamaModel) Detokenize(i32s []int) string {
var sb strings.Builder
for _, i32 := range i32s {
c := make([]byte, 512)
if n := C.llama_token_to_piece(llm.m, C.llama_token(i32), (*C.char)(unsafe.Pointer(&c[0])), C.int(len(c)), false); n > 0 {
sb.WriteString(unsafe.String(&c[0], n))
}
}
return sb.String()
}

View File

@@ -1,35 +1,32 @@
From d02a06f3f45a09255ace8684a66590e06ce44605 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 23 May 2024 11:33:20 -0700
Subject: [PATCH] default pretokenizer on unrecognized type
---
llama.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 15c66077..af1aede3 100644
index 40d2ec2c..74f3ee9c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4504,9 +4504,6 @@ static void llm_load_vocab(
LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
LLAMA_LOG_WARN("%s: \n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
- tokenizer_pre == "default") {
@@ -4642,16 +4642,7 @@ static void llm_load_vocab(
// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
- if (tokenizer_pre.empty()) {
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__);
- LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
+ if (
tokenizer_pre == "default") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if (
tokenizer_pre == "llama3" ||
tokenizer_pre == "llama-v3" ||
@@ -4553,7 +4550,7 @@ static void llm_load_vocab(
tokenizer_pre == "dbrx") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
@@ -4703,7 +4694,8 @@ static void llm_load_vocab(
tokenizer_pre == "smaug-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
} else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
}
} else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--
2.45.1

View File

@@ -57,6 +57,8 @@ type llmServer struct {
loadDuration time.Duration // Record how long it took the model to load
loadProgress float32
*llamaModel
sem *semaphore.Weighted
}
@@ -189,35 +191,38 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--memory-f32")
}
if opts.UseMLock {
params = append(params, "--mlock")
flashAttnEnabled := envconfig.FlashAttention
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
// mmap has issues with partial offloading on metal
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
opts.UseMMap = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
if !opts.UseMMap {
params = append(params, "--no-mmap")
}
if opts.UseMLock {
params = append(params, "--mlock")
}
if opts.UseNUMA {
params = append(params, "--numa")
}
flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnEnabled = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet
@@ -306,6 +311,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
totalLayers: ggml.KV().BlockCount() + 1,
gpuCount: gpuCount,
done: make(chan error, 1),
llamaModel: newLlamaModel(model),
}
s.cmd.Env = os.Environ()
@@ -843,12 +849,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(EmbeddingRequest{Content: prompt}); err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), &b)
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
@@ -878,108 +884,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return embedding.Embedding, nil
}
type TokenizeRequest struct {
Content string `json:"content"`
}
type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: content})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("encode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do encode request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read encode request: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}
var encoded TokenizeResponse
if err := json.Unmarshal(body, &encoded); err != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err)
}
return encoded.Tokens, nil
}
type DetokenizeRequest struct {
Tokens []int `json:"tokens"`
}
type DetokenizeResponse struct {
Content string `json:"content"`
return s.llamaModel.Tokenize(content), nil
}
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return "", fmt.Errorf("decode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("do decode request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read decode request: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm decode error: %s", body)
return "", fmt.Errorf("%s", body)
}
var decoded DetokenizeResponse
if err := json.Unmarshal(body, &decoded); err != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err)
}
return decoded.Content, nil
return s.llamaModel.Detokenize(tokens), nil
}
func (s *llmServer) Close() error {
@@ -997,6 +907,10 @@ func (s *llmServer) Close() error {
slog.Debug("llama server stopped")
}
if s.llamaModel != nil {
s.llamaModel.Close()
}
return nil
}

View File

@@ -771,37 +771,6 @@ func PruneDirectory(path string) error {
return nil
}
func DeleteModel(name string) error {
mp := ParseModelPath(name)
manifest, _, err := GetManifest(mp)
if err != nil {
return err
}
deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
return err
}
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})

View File

@@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return os.Open(blob)
}
func (l *Layer) Remove() error {
ms, err := Manifests()
if err != nil {
return err
}
for _, m := range ms {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == l.Digest {
// something is using this layer
return nil
}
}
}
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return err
}
return os.Remove(blob)
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
@@ -14,7 +15,10 @@ import (
type Manifest struct {
ManifestV2
Digest string `json:"-"`
filepath string
fi os.FileInfo
digest string
}
func (m *Manifest) Size() (size int64) {
@@ -25,9 +29,28 @@ func (m *Manifest) Size() (size int64) {
return
}
func ParseNamedManifest(name model.Name) (*Manifest, error) {
if !name.IsFullyQualified() {
return nil, model.Unqualified(name)
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
return PruneDirectory(manifests)
}
func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
}
manifests, err := GetManifestPath()
@@ -35,20 +58,30 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
return nil, err
}
var manifest ManifestV2
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
p := filepath.Join(manifests, n.Filepath())
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err
}
return &Manifest{
ManifestV2: manifest,
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
ManifestV2: m,
filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
}
@@ -77,3 +110,48 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
}
func Manifests() (map[model.Name]*Manifest, error) {
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil {
return nil, err
}
ms := make(map[model.Name]*Manifest)
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
return ms, nil
}

150
server/manifest_test.go Normal file
View File

@@ -0,0 +1,150 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
t.Fatal(err)
}
f, err := os.Create(p)
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}
func TestManifests(t *testing.T) {
cases := map[string]struct {
ps []string
wantValidCount int
wantInvalidCount int
}{
"empty": {},
"single": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
"multiple": {
ps: []string{
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
},
wantValidCount: 15,
},
"hidden": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag"),
filepath.Join("host", "namespace", "model", ".hidden"),
},
wantValidCount: 1,
wantInvalidCount: 1,
},
"subdir": {
ps: []string{
filepath.Join("host", "namespace", "model", "tag", "one"),
filepath.Join("host", "namespace", "model", "tag", "another", "one"),
},
wantInvalidCount: 2,
},
"upper tag": {
ps: []string{
filepath.Join("host", "namespace", "model", "TAG"),
},
wantValidCount: 1,
},
"upper model": {
ps: []string{
filepath.Join("host", "namespace", "MODEL", "tag"),
},
wantValidCount: 1,
},
"upper namespace": {
ps: []string{
filepath.Join("host", "NAMESPACE", "model", "tag"),
},
wantValidCount: 1,
},
"upper host": {
ps: []string{
filepath.Join("HOST", "namespace", "model", "tag"),
},
wantValidCount: 1,
},
}
for n, wants := range cases {
t.Run(n, func(t *testing.T) {
d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d)
for _, p := range wants.ps {
createManifest(t, d, p)
}
ms, err := Manifests()
if err != nil {
t.Fatal(err)
}
var ns []model.Name
for k := range ms {
ns = append(ns, k)
}
var gotValidCount, gotInvalidCount int
for _, p := range wants.ps {
n := model.ParseNameFromFilepath(p)
if n.IsValid() {
gotValidCount++
} else {
gotInvalidCount++
}
if !n.IsValid() && slices.Contains(ns, n) {
t.Errorf("unexpected invalid name: %s", p)
} else if n.IsValid() && !slices.Contains(ns, n) {
t.Errorf("missing valid name: %s", p)
}
}
if gotValidCount != wants.wantValidCount {
t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount)
}
if gotInvalidCount != wants.wantInvalidCount {
t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount)
}
})
}
}

View File

@@ -421,13 +421,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -445,7 +446,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -507,6 +508,21 @@ func (s *Server) PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func checkNameExists(name model.Name) error {
names, err := Manifests()
if err != nil {
return err
}
for n := range names {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
return fmt.Errorf("a model with that name already exists")
}
}
return nil
}
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -523,6 +539,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return
}
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
@@ -575,48 +596,31 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
}
func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
manifestsPath, err := GetManifestPath()
m, err := ParseNamedManifest(n)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := PruneDirectory(manifestsPath); err != nil {
if err := m.Remove(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
}
func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -720,72 +724,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
func (s *Server) ListModelsHandler(c *gin.Context) {
manifests, err := GetManifestPath()
ms, err := Manifests()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
models := []api.ModelResponse{}
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
rel, err := filepath.Rel(manifests, path)
if err != nil {
return err
}
for n, m := range ms {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
return err
} else if hidden {
return nil
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err)
return nil
}
defer f.Close()
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
return nil
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
},
})
var cf ConfigV2
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
continue
}
return nil
}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
}
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
@@ -818,6 +792,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
return
}
if err := checkNameExists(dst); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {

View File

@@ -0,0 +1,160 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
var stream bool = false
func createBinFile(t *testing.T) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
t.Fatal(err)
}
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
t.Fatal(err)
}
return f.Name()
}
type responseRecorder struct {
*httptest.ResponseRecorder
http.CloseNotifier
}
func NewRecorder() *responseRecorder {
return &responseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}
func (t *responseRecorder) CloseNotify() <-chan bool {
return make(chan bool)
}
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
t.Helper()
w := NewRecorder()
c, _ := gin.CreateTestContext(w)
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(body); err != nil {
t.Fatal(err)
}
c.Request = &http.Request{
Body: io.NopCloser(&b),
}
fn(c)
return w.ResponseRecorder
}
func checkFileExists(t *testing.T, p string, expect []string) {
t.Helper()
actual, err := filepath.Glob(p)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(actual, expect) {
t.Fatalf("expected slices to be equal %v", actual)
}
}
func TestCreateFromBin(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}
func TestCreateFromModel(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
}

View File

@@ -0,0 +1,71 @@
package server
import (
"fmt"
"net/http"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
)
func TestDelete(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
}

View File

@@ -0,0 +1,61 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"slices"
"testing"
"github.com/ollama/ollama/api"
)
func TestList(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
expectNames := []string{
"mistral:7b-instruct-q4_0",
"zephyr:7b-beta-q5_K_M",
"apple/OpenELM:latest",
"boreas:2b-code-v1.5-q6_K",
"notus:7b-v1-IQ2_S",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
"mynamespace/apeliotes:latest",
"myhost/mynamespace/lips:code",
}
var s Server
for _, n := range expectNames {
createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: n,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
})
}
w := createRequest(t, s.ListModelsHandler, nil)
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
var resp api.ListResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if len(resp.Models) != len(expectNames) {
t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
}
actualNames := make([]string, len(resp.Models))
for i, m := range resp.Models {
actualNames[i] = m.Name
}
slices.Sort(actualNames)
slices.Sort(expectNames)
if !slices.Equal(actualNames, expectNames) {
t.Fatalf("expected slices to be equal %v", actualNames)
}
}

View File

@@ -21,6 +21,28 @@ import (
"github.com/ollama/ollama/version"
)
func createTestFile(t *testing.T, name string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
func Test_Routes(t *testing.T) {
type testCase struct {
Name string
@@ -30,28 +52,6 @@ func Test_Routes(t *testing.T) {
Expected func(t *testing.T, resp *http.Response)
}
createTestFile := func(t *testing.T, name string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model")
@@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) {
})
}
}
func TestCase(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
cases := []string{
"mistral",
"llama3:latest",
"library/phi3:q4_0",
"registry.ollama.ai/library/gemma:q5_K_M",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/alice/bob:latest",
}
var s Server
for _, tt := range cases {
t.Run(tt, func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: tt,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 got %d", w.Code)
}
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
if err != nil {
t.Fatal(err)
}
t.Run("create", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: strings.ToUpper(tt),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("pull", func(t *testing.T) {
w := createRequest(t, s.PullModelHandler, api.PullRequest{
Name: strings.ToUpper(tt),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("copy", func(t *testing.T) {
w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
Source: tt,
Destination: strings.ToUpper(tt),
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
})
}
}