mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
bugfix: don't include both consolidated.safetensors and model-*.safetensors (#13010)
This commit is contained in:
@@ -260,10 +260,13 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
||||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
|
if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
|
||||||
// safetensors files might be unresolved git lfs references; skip if they are
|
// safetensors files might be unresolved git lfs references; skip if they are
|
||||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||||
files = append(files, st...)
|
files = append(files, st...)
|
||||||
|
} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
|
||||||
|
// covers consolidated.safetensors
|
||||||
|
files = append(files, st...)
|
||||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"unicode/utf16"
|
"unicode/utf16"
|
||||||
@@ -855,3 +856,273 @@ func TestCreateRequestFiles(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilesForModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(string) error
|
||||||
|
wantFiles []string
|
||||||
|
wantErr bool
|
||||||
|
expectErrType error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "safetensors model files",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
files := []string{
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
"config.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
"config.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "safetensors with consolidated files - prefers model files",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
files := []string{
|
||||||
|
"model-00001-of-00001.safetensors",
|
||||||
|
"consolidated.safetensors",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"model-00001-of-00001.safetensors", // consolidated files should be excluded
|
||||||
|
"config.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "safetensors without model-.safetensors files - uses consolidated",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
files := []string{
|
||||||
|
"consolidated.safetensors",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"consolidated.safetensors",
|
||||||
|
"config.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "pytorch model files",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
// Create a file that will be detected as application/zip
|
||||||
|
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header
|
||||||
|
files := []string{
|
||||||
|
"pytorch_model-00001-of-00002.bin",
|
||||||
|
"pytorch_model-00002-of-00002.bin",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
content := zipHeader
|
||||||
|
if file == "config.json" {
|
||||||
|
content = []byte(`{"config": true}`)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"pytorch_model-00001-of-00002.bin",
|
||||||
|
"pytorch_model-00002-of-00002.bin",
|
||||||
|
"config.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "consolidated pth files",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
zipHeader := []byte{0x50, 0x4B, 0x03, 0x04}
|
||||||
|
files := []string{
|
||||||
|
"consolidated.00.pth",
|
||||||
|
"consolidated.01.pth",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
content := zipHeader
|
||||||
|
if file == "config.json" {
|
||||||
|
content = []byte(`{"config": true}`)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"consolidated.00.pth",
|
||||||
|
"consolidated.01.pth",
|
||||||
|
"config.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gguf files",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
// Create binary content that will be detected as application/octet-stream
|
||||||
|
binaryContent := make([]byte, 512)
|
||||||
|
for i := range binaryContent {
|
||||||
|
binaryContent[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
files := []string{
|
||||||
|
"model.gguf",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
content := binaryContent
|
||||||
|
if file == "config.json" {
|
||||||
|
content = []byte(`{"config": true}`)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"model.gguf",
|
||||||
|
"config.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bin files as gguf",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
binaryContent := make([]byte, 512)
|
||||||
|
for i := range binaryContent {
|
||||||
|
binaryContent[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
files := []string{
|
||||||
|
"model.bin",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
content := binaryContent
|
||||||
|
if file == "config.json" {
|
||||||
|
content = []byte(`{"config": true}`)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantFiles: []string{
|
||||||
|
"model.bin",
|
||||||
|
"config.json",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no model files found",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
// Only create non-model files
|
||||||
|
files := []string{"README.md", "config.json"}
|
||||||
|
for _, file := range files {
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), []byte("content"), 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectErrType: ErrModelNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid content type for pytorch model",
|
||||||
|
setup: func(dir string) error {
|
||||||
|
// Create pytorch model file with wrong content type (text instead of zip)
|
||||||
|
files := []string{
|
||||||
|
"pytorch_model.bin",
|
||||||
|
"config.json",
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
content := []byte("plain text content")
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, file), content, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
testDir := filepath.Join(tmpDir, tt.name)
|
||||||
|
if err := os.MkdirAll(testDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("Failed to create test directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tt.setup(testDir); err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
files, err := filesForModel(testDir)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error, but got none")
|
||||||
|
}
|
||||||
|
if tt.expectErrType != nil && err != tt.expectErrType {
|
||||||
|
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var relativeFiles []string
|
||||||
|
for _, file := range files {
|
||||||
|
rel, err := filepath.Rel(testDir, file)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get relative path: %v", err)
|
||||||
|
}
|
||||||
|
relativeFiles = append(relativeFiles, rel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relativeFiles) != len(tt.wantFiles) {
|
||||||
|
t.Errorf("Expected %d files, got %d: %v", len(tt.wantFiles), len(relativeFiles), relativeFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileSet := make(map[string]bool)
|
||||||
|
for _, file := range relativeFiles {
|
||||||
|
fileSet[file] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, wantFile := range tt.wantFiles {
|
||||||
|
if !fileSet[wantFile] {
|
||||||
|
t.Errorf("Missing expected file: %s", wantFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user