continue conversation

feed responses back into the llm
Michael Yang
2023-07-13 11:02:53 -07:00
parent 77dc1a6d74
commit 1775647f76
11 changed files with 47 additions and 10 deletions


@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}
 
 	return errors.New("llama: tokenize")
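The new Predict signature above takes the previous conversation as plain token IDs (ctx []int), converts them to llama tokens, and prepends them to the freshly tokenized prompt before generation. A minimal cgo-free sketch of that prepend pattern; the token alias and prependContext helper are illustrative names, and int32 standing in for C.llama_token is an assumption:

package main

import "fmt"

// token stands in for C.llama_token; int32 is an assumption made for this
// cgo-free sketch, not the definition used by llama.cpp.
type token = int32

// prependContext mirrors the new Predict above: prior context token IDs (plain
// ints, as a caller would hold them) are converted to the token type and
// prepended to the freshly tokenized prompt before generation.
func prependContext(ctx []int, prompt []token) []token {
	embd := make([]token, len(ctx))
	for i := range ctx {
		embd[i] = token(ctx[i])
	}
	return append(embd, prompt...)
}

func main() {
	prior := []int{7, 8, 9}          // token IDs from an earlier response
	prompt := []token{101, 102, 103} // freshly tokenized follow-up prompt
	fmt.Println(prependContext(prior, prompt)) // [7 8 9 101 102 103]
}
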
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})
 
 		output.PushLeft(token)
+		context.PushLeft(int(token))
 		input = []C.llama_token{token}
 	}
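For context on the two hunks above: the prompt tokens are pushed into the context deque up front, and each newly sampled token is recorded there as well and then becomes the sole input for the next eval step, since earlier tokens already sit in the kv cache. A cgo-free sketch of that incremental loop; decode, evalFn and sampleFn are hypothetical stand-ins for the real llama_eval and sampling code:

package main

import "fmt"

// decode sketches the incremental loop, without cgo: the prompt is evaluated
// once, then each sampled token is appended to the running context and fed
// back as the only input for the next step.
func decode(prompt []int32, maxTokens int, evalFn func([]int32) error, sampleFn func() int32) ([]int, error) {
	context := make([]int, 0, len(prompt)+maxTokens)
	for _, t := range prompt {
		context = append(context, int(t))
	}

	input := prompt
	for i := 0; i < maxTokens; i++ {
		if err := evalFn(input); err != nil {
			return nil, err
		}

		token := sampleFn()
		context = append(context, int(token))
		input = []int32{token} // only the newly sampled token is evaluated next
	}

	return context, nil
}

func main() {
	next := int32(200)
	out, _ := decode([]int32{101, 102}, 3,
		func([]int32) error { return nil },   // stand-in for llama_eval
		func() int32 { next++; return next }) // stand-in for sampling
	fmt.Println(out) // [101 102 201 202 203]
}
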
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
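
Taken together, these hunks let a caller carry a conversation across requests: the final GenerateResponse now includes Context, and the next Predict call accepts it back. A minimal caller-side sketch under stated assumptions; GenerateResponse is a simplified stand-in exposing only the Done and Context fields visible in this diff, and predictor, converse and echo are illustrative names, not part of the repository:

package main

import "fmt"

// GenerateResponse is a simplified stand-in for api.GenerateResponse; only the
// Done and Context fields shown in the diff are assumed here.
type GenerateResponse struct {
	Done    bool
	Context []int
}

// predictor matches the new Predict signature introduced in this commit.
type predictor interface {
	Predict(ctx []int, prompt string, fn func(GenerateResponse)) error
}

// converse sends each prompt in turn, feeding the Context from the previous
// response back into the next Predict call so the model sees the whole exchange.
func converse(llm predictor, prompts []string) ([]int, error) {
	var ctx []int
	for _, p := range prompts {
		if err := llm.Predict(ctx, p, func(r GenerateResponse) {
			if r.Done {
				ctx = r.Context // carry the conversation into the next turn
			}
		}); err != nil {
			return nil, err
		}
	}
	return ctx, nil
}

// echo is a toy predictor used only to exercise converse in this sketch: it
// "generates" one token per prompt and returns the accumulated context.
type echo struct{}

func (echo) Predict(ctx []int, prompt string, fn func(GenerateResponse)) error {
	fn(GenerateResponse{Done: true, Context: append(append([]int(nil), ctx...), len(prompt))})
	return nil
}

func main() {
	ctx, _ := converse(echo{}, []string{"hello", "why is the sky blue?"})
	fmt.Println(ctx) // [5 20]: the context grew across both turns
}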