convert gemma2

This commit is contained in:
Michael Yang
2024-06-28 13:27:05 -07:00
parent beb49eef65
commit 3546bbd08c
13 changed files with 132 additions and 46 deletions

View File

@@ -7,6 +7,7 @@ import (
"io"
"io/fs"
"log/slog"
"strings"
"github.com/ollama/ollama/llm"
)
@@ -58,11 +59,13 @@ type Converter interface {
KV(*Tokenizer) llm.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []llm.Tensor
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
// tensorName returns the LLM tensor name for a specific input name
tensorName(string) string
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
@@ -97,6 +100,8 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
conv = &mixtral{}
case "GemmaForCausalLM":
conv = &gemma{}
case "Gemma2ForCausalLM":
conv = &gemma2{}
case "Phi3ForCausalLM":
conv = &phi3{}
case "BertModel":
@@ -131,7 +136,7 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
ts, err := parseTensors(fsys)
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
if err != nil {
return err
}