mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-22 23:03:55 +00:00
Merge branch 'ollama:main' into main
This commit is contained in:
36
llm/ggml.go
36
llm/ggml.go
@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
|
||||
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
embedding := llm.KV().EmbeddingLength()
|
||||
heads := llm.KV().HeadCount()
|
||||
headsKV := llm.KV().HeadCountKV()
|
||||
@@ -372,7 +372,8 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
|
||||
|
||||
layers := llm.Tensors().Layers()
|
||||
|
||||
kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
|
||||
switch llm.KV().Architecture() {
|
||||
case "llama":
|
||||
@@ -527,3 +528,34 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
|
||||
return slices.Contains(validKVCacheTypes, cacheType)
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (ggml GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
|
||||
if isEmbedding {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := ggml.KV().EmbeddingHeadCountK()
|
||||
headCountV := ggml.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
case "q8_0":
|
||||
return 1 // 1/2 of fp16
|
||||
case "q4_0":
|
||||
return 0.5 // 1/4 of fp16
|
||||
default:
|
||||
return 2 // f16 (default)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,7 +123,23 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
|
||||
slog.Warn("model missing blk.0 layer size")
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||
fa := envconfig.FlashAttention() &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
ggml.SupportsFlashAttention()
|
||||
|
||||
var kvct string
|
||||
if fa {
|
||||
requested := strings.ToLower(envconfig.KvCacheType())
|
||||
if requested != "" && ggml.SupportsKVCacheType(requested) {
|
||||
kvct = requested
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / ggml.KV().BlockCount()
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||
}
|
||||
@@ -131,9 +147,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string,
|
||||
graphFullOffload = graphPartialOffload
|
||||
}
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / ggml.KV().BlockCount()
|
||||
|
||||
// on metal there's no partial offload overhead
|
||||
if gpus[0].Library == "metal" {
|
||||
graphPartialOffload = graphFullOffload
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
func TestEstimateGPULayers(t *testing.T) {
|
||||
t.Setenv("OLLAMA_DEBUG", "1")
|
||||
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
|
||||
|
||||
modelName := "dummy"
|
||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||
|
||||
@@ -214,15 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
||||
params = append(params, "--threads", strconv.Itoa(defaultThreads))
|
||||
}
|
||||
|
||||
flashAttnEnabled := envconfig.FlashAttention()
|
||||
fa := envconfig.FlashAttention()
|
||||
if fa && !gpus.FlashAttentionSupported() {
|
||||
slog.Warn("flash attention enabled but not supported by gpu")
|
||||
fa = false
|
||||
}
|
||||
|
||||
for _, g := range gpus {
|
||||
// only cuda (compute capability 7+) and metal support flash attention
|
||||
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
|
||||
flashAttnEnabled = false
|
||||
if fa && !ggml.SupportsFlashAttention() {
|
||||
slog.Warn("flash attention enabled but not supported by model")
|
||||
fa = false
|
||||
}
|
||||
|
||||
kvct := strings.ToLower(envconfig.KvCacheType())
|
||||
|
||||
if fa {
|
||||
slog.Info("enabling flash attention")
|
||||
params = append(params, "--flash-attn")
|
||||
|
||||
// Flash Attention also supports kv cache quantization
|
||||
// Enable if the requested and kv cache type is supported by the model
|
||||
if kvct != "" && ggml.SupportsKVCacheType(kvct) {
|
||||
params = append(params, "--kv-cache-type", kvct)
|
||||
} else {
|
||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||
}
|
||||
} else if kvct != "" && kvct != "f16" {
|
||||
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||
}
|
||||
|
||||
// mmap has issues with partial offloading on metal
|
||||
// mmap has issues with partial offloading on metal
|
||||
for _, g := range gpus {
|
||||
if g.Library == "metal" &&
|
||||
uint64(opts.NumGPU) > 0 &&
|
||||
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||
@@ -231,10 +252,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
|
||||
}
|
||||
}
|
||||
|
||||
if flashAttnEnabled {
|
||||
params = append(params, "--flash-attn")
|
||||
}
|
||||
|
||||
// Windows CUDA should not use mmap for best performance
|
||||
// Linux with a model larger than free space, mmap leads to thrashing
|
||||
// For CPU loads we want the memory to be allocated, not FS cache
|
||||
@@ -617,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
const jsonGrammar = `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
@@ -667,7 +679,7 @@ type completion struct {
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Format json.RawMessage
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
}
|
||||
@@ -732,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
// TODO (parthsareen): Move conversion to grammar with sampling logic
|
||||
// API should do error handling for invalid formats
|
||||
if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
|
||||
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
}
|
||||
} else if schema, err := func() (llama.JsonSchema, error) {
|
||||
var schema llama.JsonSchema
|
||||
err := json.Unmarshal(req.Format, &schema)
|
||||
return schema, err
|
||||
}(); err == nil {
|
||||
request["grammar"] = schema.AsGrammar()
|
||||
} else {
|
||||
slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user