llama_cpp/llama.py (12 additions, 0 deletions)
@@ -90,6 +90,8 @@ def __init__(
         yarn_orig_ctx: int = 0,
         logits_all: bool = False,
         embedding: bool = False,
+        n_seq_max: Optional[int] = None,
+        kv_unified: Optional[bool] = None,
         offload_kqv: bool = True,
         flash_attn: bool = False,
         op_offload: Optional[bool] = None,
@@ -172,6 +174,8 @@ def __init__(
             yarn_orig_ctx: YaRN original context size
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
+            n_seq_max: Maximum number of sequences in the KV cache.
+            kv_unified: Use a unified KV cache shared across sequences.
             offload_kqv: Offload K, Q, V to GPU.
             flash_attn: Use flash attention.
             op_offload: offload host tensor operations to device
@@ -343,6 +347,14 @@ def __init__(
         self.context_params.offload_kqv = offload_kqv
         self.context_params.flash_attn = flash_attn
 
+        # Allow batch embedding of many sequences via a unified KV cache.
+        if n_seq_max is not None:
+            self.context_params.n_seq_max = n_seq_max
+        if kv_unified is not None:
+            self.context_params.kv_unified = kv_unified
+        elif embedding and n_seq_max is None:
+            self.context_params.kv_unified = True
+
         if op_offload is not None:
             self.context_params.op_offload = op_offload
 
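A minimal usage sketch of the new parameters (not part of the PR; the model path, batch size, and input strings below are placeholders):

from llama_cpp import Llama

# Hypothetical example: batch-embed several documents in one context.
# n_seq_max and kv_unified are the parameters added by this PR; the
# model path is a placeholder for any GGUF embedding model.
llm = Llama(
    model_path="./models/embedding-model.gguf",  # placeholder path
    embedding=True,     # embedding mode only
    n_seq_max=8,        # allow up to 8 sequences in the KV cache
    kv_unified=True,    # share one KV cache across those sequences
)

docs = ["first document", "second document", "third document"]
vectors = llm.embed(docs)  # embed() accepts a list of strings, one vector per input

Per the diff above, passing embedding=True without n_seq_max also defaults kv_unified to True, so the explicit kv_unified=True is only needed when n_seq_max is set.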