chat api endpoint (#1392)

This commit is contained in:
Bruce MacDonald
2023-12-05 14:57:33 -05:00
committed by GitHub
parent 00d06619a1
commit 195e3d9dbd
9 changed files with 550 additions and 132 deletions

View File

@@ -47,37 +47,85 @@ type Model struct {
Options map[string]interface{}
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
}
type PromptVars struct {
System string
Prompt string
Response string
First bool
}
tmpl, err := template.New("").Parse(t)
func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
tmpl, err := template.New("").Parse(m.Template)
if err != nil {
return "", err
}
var vars struct {
First bool
System string
Prompt string
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
if err := tmpl.Execute(&sb, p); err != nil {
return "", err
}
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
return prompt.String(), nil
}
return sb.String(), nil
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
// build the prompt from the list of messages
var prompt strings.Builder
currentVars := PromptVars{
First: true,
}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch msg.Role {
case "system":
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.Prompt = msg.Content
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
return prompt.String(), nil
}
type ManifestV2 struct {
@@ -383,7 +431,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
c.Args = blobPath
}
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil {