mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
continue conversation
feed responses back into the llm
This commit is contained in:
@@ -149,9 +149,14 @@ func (llm *llama) Close() {
|
||||
C.llama_print_timings(llm.ctx)
|
||||
}
|
||||
|
||||
func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
|
||||
if tokens := llm.tokenize(prompt); tokens != nil {
|
||||
return llm.generate(tokens, fn)
|
||||
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||
if input := llm.tokenize(prompt); input != nil {
|
||||
embd := make([]C.llama_token, len(ctx))
|
||||
for i := range ctx {
|
||||
embd[i] = C.llama_token(ctx[i])
|
||||
}
|
||||
|
||||
return llm.generate(append(embd, input...), fn)
|
||||
}
|
||||
|
||||
return errors.New("llama: tokenize")
|
||||
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
|
||||
|
||||
output := deque[C.llama_token]{capacity: llm.NumCtx}
|
||||
|
||||
context := deque[int]{capacity: llm.NumCtx / 2}
|
||||
for _, in := range input {
|
||||
context.PushLeft(int(in))
|
||||
}
|
||||
|
||||
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
|
||||
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
|
||||
return errors.New("llama: eval")
|
||||
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
|
||||
})
|
||||
|
||||
output.PushLeft(token)
|
||||
context.PushLeft(int(token))
|
||||
|
||||
input = []C.llama_token{token}
|
||||
}
|
||||
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
|
||||
timings := C.llama_get_timings(llm.ctx)
|
||||
fn(api.GenerateResponse{
|
||||
Done: true,
|
||||
Context: context.Data(),
|
||||
PromptEvalCount: int(timings.n_p_eval),
|
||||
PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
|
||||
EvalCount: int(timings.n_eval),
|
||||
|
||||
Reference in New Issue
Block a user