diff --git a/parser/parser.go b/parser/parser.go
index c2e8f981..7d52c338 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -260,10 +260,13 @@ func filesForModel(path string) ([]string, error) {
 
 	var files []string
 	// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
-	if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
+	if st, _ := glob(filepath.Join(path, "model*.safetensors"), ""); len(st) > 0 {
 		// safetensors files might be unresolved git lfs references; skip if they are
 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		files = append(files, st...)
+	} else if st, _ := glob(filepath.Join(path, "consolidated*.safetensors"), ""); len(st) > 0 {
+		// covers consolidated.safetensors
+		files = append(files, st...)
 	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 1e1fc452..3300aad3 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -9,6 +9,7 @@ import (
 	"io"
 	"maps"
 	"os"
+	"path/filepath"
 	"strings"
 	"testing"
 	"unicode/utf16"
@@ -855,3 +856,190 @@ func TestCreateRequestFiles(t *testing.T) {
 		}
 	}
 }
+
+// TestFilesForModel writes a set of fixture files into a fresh directory and
+// checks which of them filesForModel reports. The cases expect "model*"
+// safetensors to win over "consolidated*" safetensors, and expect an error
+// when the directory holds no acceptable weight files.
+func TestFilesForModel(t *testing.T) {
+	// Shared fixture payloads.
+	var (
+		plain   = []byte("test content")
+		cfgJSON = []byte(`{"config": true}`)
+		zipHdr  = []byte{0x50, 0x4B, 0x03, 0x04} // PK zip header, to be detected as application/zip
+	)
+	// Binary content that will be detected as application/octet-stream.
+	octets := make([]byte, 512)
+	for i := range octets {
+		octets[i] = byte(i % 256)
+	}
+
+	tests := []struct {
+		name          string            // subtest name
+		files         map[string][]byte // file name to content, written before the call
+		wantFiles     []string          // expected paths, relative to the model directory
+		wantErr       bool              // whether filesForModel must fail
+		expectErrType error             // optional sentinel the returned error must equal
+	}{
+		{
+			name: "safetensors model files",
+			files: map[string][]byte{
+				"model-00001-of-00002.safetensors": plain,
+				"model-00002-of-00002.safetensors": plain,
+				"config.json":                      plain,
+				"tokenizer.json":                   plain,
+			},
+			wantFiles: []string{
+				"model-00001-of-00002.safetensors",
+				"model-00002-of-00002.safetensors",
+				"config.json",
+				"tokenizer.json",
+			},
+		},
+		{
+			name: "safetensors with consolidated files - prefers model files",
+			files: map[string][]byte{
+				"model-00001-of-00001.safetensors": plain,
+				"consolidated.safetensors":         plain,
+				"config.json":                      plain,
+			},
+			wantFiles: []string{
+				"model-00001-of-00001.safetensors", // consolidated files should be excluded
+				"config.json",
+			},
+		},
+		{
+			name: "safetensors without model-.safetensors files - uses consolidated",
+			files: map[string][]byte{
+				"consolidated.safetensors": plain,
+				"config.json":              plain,
+			},
+			wantFiles: []string{
+				"consolidated.safetensors",
+				"config.json",
+			},
+		},
+		{
+			name: "pytorch model files",
+			files: map[string][]byte{
+				"pytorch_model-00001-of-00002.bin": zipHdr,
+				"pytorch_model-00002-of-00002.bin": zipHdr,
+				"config.json":                      cfgJSON,
+			},
+			wantFiles: []string{
+				"pytorch_model-00001-of-00002.bin",
+				"pytorch_model-00002-of-00002.bin",
+				"config.json",
+			},
+		},
+		{
+			name: "consolidated pth files",
+			files: map[string][]byte{
+				"consolidated.00.pth": zipHdr,
+				"consolidated.01.pth": zipHdr,
+				"config.json":         cfgJSON,
+			},
+			wantFiles: []string{
+				"consolidated.00.pth",
+				"consolidated.01.pth",
+				"config.json",
+			},
+		},
+		{
+			name: "gguf files",
+			files: map[string][]byte{
+				"model.gguf":  octets,
+				"config.json": cfgJSON,
+			},
+			wantFiles: []string{
+				"model.gguf",
+				"config.json",
+			},
+		},
+		{
+			name: "bin files as gguf",
+			files: map[string][]byte{
+				"model.bin":   octets,
+				"config.json": cfgJSON,
+			},
+			wantFiles: []string{
+				"model.bin",
+				"config.json",
+			},
+		},
+		{
+			name: "no model files found",
+			// Only non-model files are present.
+			files: map[string][]byte{
+				"README.md":   []byte("content"),
+				"config.json": []byte("content"),
+			},
+			wantErr:       true,
+			expectErrType: ErrModelNotFound,
+		},
+		{
+			name: "invalid content type for pytorch model",
+			// The pytorch model file holds plain text instead of a zip archive.
+			files: map[string][]byte{
+				"pytorch_model.bin": []byte("plain text content"),
+				"config.json":       []byte("plain text content"),
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// A per-subtest temp dir keeps cases isolated and is removed automatically.
+			dir := t.TempDir()
+			for name, content := range tt.files {
+				if err := os.WriteFile(filepath.Join(dir, name), content, 0o644); err != nil {
+					t.Fatalf("Setup failed: %v", err)
+				}
+			}
+
+			files, err := filesForModel(dir)
+
+			if tt.wantErr {
+				if err == nil {
+					// Fatal, not Error: comparing a nil error below would only add noise.
+					t.Fatal("Expected error, but got none")
+				}
+				// NOTE(review): identity comparison assumes the sentinel is returned unwrapped.
+				if tt.expectErrType != nil && err != tt.expectErrType {
+					t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("Unexpected error: %v", err)
+			}
+
+			// Compare as sets of paths relative to dir; result order is not checked.
+			got := make(map[string]bool, len(files))
+			for _, file := range files {
+				rel, err := filepath.Rel(dir, file)
+				if err != nil {
+					t.Fatalf("Failed to get relative path: %v", err)
+				}
+				if got[rel] {
+					t.Errorf("File returned more than once: %s", rel)
+				}
+				got[rel] = true
+			}
+
+			want := make(map[string]bool, len(tt.wantFiles))
+			for _, wantFile := range tt.wantFiles {
+				want[wantFile] = true
+				if !got[wantFile] {
+					t.Errorf("Missing expected file: %s", wantFile)
+				}
+			}
+			for file := range got {
+				if !want[file] {
+					t.Errorf("Unexpected file: %s", file)
+				}
+			}
+		})
+	}
+}