diff --git a/README.md b/README.md index e0232478c75a2..0401723ffcf87 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics +- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md) - **GGML developer experience survey (organized and reviewed by NVIDIA):** [link](https://forms.gle/Gasw3cRgyhNEnrwK9) -- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141]((https://github.com/ggml-org/llama.cpp/pull/13141))), `libllava` will be deprecated +- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim diff --git a/common/arg.cpp b/common/arg.cpp index 73a3cfe5392c0..f67e0d96d702a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -40,7 +40,7 @@ using json = nlohmann::ordered_json; std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_LLAVA, - // TODO: add LLAMA_EXAMPLE_SERVER when it's ready + LLAMA_EXAMPLE_SERVER, }; static std::string read_file(const std::string & fname) { diff --git a/docs/multimodal.md b/docs/multimodal.md new file mode 100644 index 0000000000000..efed473a3cd07 --- /dev/null +++ b/docs/multimodal.md @@ -0,0 +1,69 @@ +# Multimodal + +llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature: +- [llama-mtmd-cli](../tools/mtmd/README.md) +- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API + +To enable it, can use use one of the 2 methods below: + +- Use `-hf` option with a [supported model](../../docs/multimodal.md) + - To load a model using `-hf` while disabling multimodal, use `--no-mmproj` + - To load a model using `-hf` while using a custom mmproj file, use `--mmproj local_file.gguf` +- Use `-m model.gguf` option with `--mmproj file.gguf` to specify text and multimodal projector respectively + +By default, multimodal projector will be offloaded to GPU. To disable this, add `--no-mmproj-offload` + +For example: + +```sh +# simple usage with CLI +llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF + +# simple usage with server +llama-server -hf ggml-org/gemma-3-4b-it-GGUF + +# using local file +llama-server -m gemma-3-4b-it-Q4_K_M.gguf --mmproj mmproj-gemma-3-4b-it-Q4_K_M.gguf + +# no GPU offload +llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload +``` + +## Pre-quantized models + +These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. + +Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server` + +NOTE: some models may require large context window, for example: `-c 8192` + +```sh +# Gemma 3 +(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-12b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-27b-it-GGUF + +# SmolVLM +(tool_name) -hf ggml-org/SmolVLM-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-256M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-500M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF + +# Pixtral 12B +(tool_name) -hf ggml-org/pixtral-12b-GGUF + +# Qwen 2 VL +(tool_name) -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF + +# Qwen 2.5 VL +(tool_name) -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF + +# Mistral Small 3.1 24B (IQ2_M quantization) +(tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF +``` diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md index 20e7696cefd8e..06e1fd097423a 100644 --- a/tools/mtmd/README.md +++ b/tools/mtmd/README.md @@ -16,38 +16,7 @@ The naming and structure related to multimodal support have evolved, which might ## Pre-quantized models -These are ready-to-use models, most of them come with `Q4_K_M` quantization by default: - -```sh -# Gemma 3 -llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF -llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF -llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF - -# SmolVLM -llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF - -# Pixtral 12B -llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF - -# Qwen 2 VL -llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF - -# Qwen 2.5 VL -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF - -# Mistral Small 3.1 24B (IQ2_M quantization) -llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF -``` +See the list of pre-quantized model [here](../../docs/multimodal.md) ## How it works and what is `mmproj`? diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index aee90388e4fb3..17109fddbd307 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -34,8 +34,9 @@ endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) +target_include_directories(${TARGET} PRIVATE ../llava) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/tools/server/README.md b/tools/server/README.md index 0ec786ea76f7a..972ca384e69a9 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -193,6 +193,12 @@ services: LLAMA_ARG_PORT: 8080 ``` +### Multimodal support + +Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature. + +For more details, please refer to [multimodal documentation](../../docs/multimodal.md) + ## Build `llama-server` is built alongside everything else from the root of the project @@ -749,6 +755,9 @@ This endpoint is public (no API key check). By default, it is read-only. To make "total_slots": 1, "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", "chat_template": "...", + "modalities": { + "vision": false + }, "build_info": "b(build number)-(build commit hash)" } ``` @@ -757,6 +766,7 @@ This endpoint is public (no API key check). By default, it is read-only. To make - `total_slots` - the total number of slots for process requests (defined by `--parallel` option) - `model_path` - the path to model file (same with `-m` argument) - `chat_template` - the model's original Jinja2 prompt template +- `modalities` - the list of supported modalities ### POST `/props`: Change server global properties. @@ -1069,6 +1079,8 @@ print(completion.choices[0].text) Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. +If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more. + *Options:* See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported. diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 06788bbdc8545..de8ded71fd6ad 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -7,6 +7,7 @@ #include "log.h" #include "sampling.h" #include "speculative.h" +#include "mtmd.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -197,8 +198,8 @@ struct server_task { int id_target = -1; // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - llama_tokens prompt_tokens; + slot_params params; + server_tokens prompt_tokens; int id_selected_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE @@ -1248,6 +1249,9 @@ struct server_slot { llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + common_speculative * spec = nullptr; std::vector lora; @@ -1275,14 +1279,14 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; // input prompt tokens - llama_tokens prompt_tokens; + server_tokens prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - llama_tokens cache_tokens; + server_tokens cache_tokens; std::vector generated_token_probs; @@ -1476,7 +1480,7 @@ struct server_slot { {"is_processing", is_processing()}, {"non_causal", is_non_causal()}, {"params", params.to_json()}, - {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"prompt", prompt_tokens.detokenize(ctx, true)}, {"next_token", { {"has_next_token", has_next_token}, @@ -1849,13 +1853,16 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + const llama_vocab * vocab = nullptr; llama_model * model_dft = nullptr; llama_context_params cparams_dft; - llama_batch batch = {}; + llama_batch batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1878,6 +1885,8 @@ struct server_context { common_chat_templates_ptr chat_templates; ~server_context() { + mtmd_free(mctx); + // Clear any sampling context for (server_slot & slot : slots) { common_sampler_free(slot.smpl); @@ -1965,6 +1974,36 @@ struct server_context { chat_templates = common_chat_templates_init(model, "chatml"); } + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + return true; } @@ -1980,6 +2019,8 @@ struct server_context { slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2016,8 +2057,6 @@ struct server_context { // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -2054,7 +2093,7 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); @@ -2096,18 +2135,6 @@ struct server_context { return ret; } - bool can_be_detokenized(const struct llama_context * ctx, const std::vector & tokens) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & token : tokens) { - if (token < 0 || token >= n_vocab) { - return false; - } - } - return true; - } - bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); slot.id_task = task.id; @@ -2122,8 +2149,7 @@ struct server_context { slot.lora = slot.params.lora; } - bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens); - if (!can_detokenize) { + if (!slot.prompt_tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -2385,6 +2411,15 @@ struct server_context { queue_results.send(std::move(res)); } + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { auto res = std::make_unique(); @@ -2424,7 +2459,7 @@ struct server_context { res->content = std::move(slot.generated_text); res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->prompt = slot.prompt_tokens.detokenize(ctx, true); res->response_fields = std::move(slot.params.response_fields); res->truncated = slot.truncated; @@ -2734,6 +2769,10 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { + if (!ensure_no_mtmd(task.id)) { + break; + } + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2753,7 +2792,8 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2770,6 +2810,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { + if (!ensure_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2788,15 +2829,18 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - slot->cache_tokens.resize(slot->n_ctx); + llama_tokens tokens; + tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.resize(0); + slot->cache_tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } - slot->cache_tokens.resize(token_count); + tokens.resize(token_count); + slot->cache_tokens.clear(); + slot->cache_tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -2813,6 +2857,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { + if (!ensure_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2844,6 +2889,7 @@ struct server_context { res->id = task.id; queue_results.send(std::move(res)); } break; + } } @@ -2889,6 +2935,12 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = slot.n_past - n_keep; @@ -2900,11 +2952,14 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); if (slot.params.cache_prompt) { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + new_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.clear(); + slot.cache_tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -2982,7 +3037,7 @@ struct server_context { SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) - if (1) { + /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); @@ -2992,7 +3047,7 @@ struct server_context { for (int i = 0; i < (int) prompt_tokens.size(); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - } + }*/ // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { @@ -3034,21 +3089,27 @@ struct server_context { // if input prompt is too big, truncate it if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); llama_tokens new_tokens( - prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); new_tokens.insert( new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - prompt_tokens.end()); + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); - prompt_tokens = std::move(new_tokens); + prompt_tokens.clear(); + prompt_tokens.insert(new_tokens); slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); @@ -3060,13 +3121,18 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache size_t head_p = slot.n_past; // current prompt + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); while (head_c < slot.cache_tokens.size() && @@ -3092,7 +3158,7 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); slot.n_past++; } @@ -3140,21 +3206,52 @@ struct server_context { // remove the non-common part from the cache slot.cache_tokens.resize(slot.n_past); + // check if we should process the image + if (slot.n_past < slot.n_prompt_tokens + && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + // process the image + int32_t new_n_past; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + int32_t n_pos = new_n_past - slot.n_past; + + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + if (slot.params.cache_prompt) { + const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); + slot.cache_tokens.push_back(chunk.get()); // copy + } + + slot.n_past += n_pos; + slot.n_prompt_tokens_processed += n_pos; + } + // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); - + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + slot.cache_tokens.push_back(cur_tok); } slot.n_prompt_tokens_processed++; slot.n_past++; } + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed @@ -3162,12 +3259,16 @@ struct server_context { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system for (int i = 0; i < slot.n_prompt_tokens; ++i) { - common_sampler_accept(slot.smpl, prompt_tokens[i], false); + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } } // extract the logits only for the last token @@ -3320,6 +3421,11 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + // determine the max draft that fits the current slot state int n_draft_max = slot.params.speculative.n_max; @@ -3346,7 +3452,8 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // keep track of total number of tokens generated in the draft slot.n_draft_total += draft.size(); @@ -3380,7 +3487,7 @@ struct server_context { slot.n_draft_accepted += ids.size() - 1; slot.cache_tokens.push_back(id); - slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); @@ -3903,6 +4010,7 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, + { "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, @@ -3950,9 +4058,10 @@ int main(int argc, char ** argv) { const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, + const std::vector & files, const std::function & is_connection_closed, httplib::Response & res, - oaicompat_type oaicompat) { + oaicompat_type oaicompat) -> void { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); if (ctx_server.params_base.embedding) { @@ -3969,15 +4078,69 @@ int main(int argc, char ** argv) { // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); - tasks.reserve(tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { + // process files + mtmd::bitmaps bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + } + + // process prompt + std::vector inputs; + if (oaicompat && !prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + } + + if (oaicompat && has_mtmd) { + // multimodal + std::string prompt_str = prompt.get(); + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, @@ -4059,9 +4222,11 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_impl( + std::vector files; // dummy + handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); @@ -4069,9 +4234,11 @@ int main(int argc, char ** argv) { const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = oaicompat_completion_params_parse(json::parse(req.body)); - return handle_completions_impl( + std::vector files; // dummy + handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_COMPLETION); @@ -4146,9 +4313,11 @@ int main(int argc, char ** argv) { tokenized_prompts[0] ); - return handle_completions_impl( + std::vector files; // dummy + handle_completions_impl( SERVER_TASK_TYPE_INFILL, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); // infill is not OAI compatible @@ -4162,11 +4331,19 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; + json data = oaicompat_completion_params_parse( + body, + params.use_jinja, + params.reasoning_format, + ctx_server.chat_templates.get(), + ctx_server.mctx, + files); - return handle_completions_impl( + handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_CHAT); @@ -4175,7 +4352,14 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; // dummy, unused + json data = oaicompat_completion_params_parse( + body, + params.use_jinja, + params.reasoning_format, + ctx_server.chat_templates.get(), + ctx_server.mctx, + files); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4280,7 +4464,7 @@ int main(int argc, char ** argv) { } } - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { @@ -4300,7 +4484,7 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr); // OAI-compat task.params.oaicompat = oaicompat; @@ -4394,13 +4578,14 @@ int main(int argc, char ** argv) { std::unordered_set task_ids; { std::vector tasks; - std::vector tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); + auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { + auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); + task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr); tasks.push_back(std::move(task)); } diff --git a/tools/server/tests/unit/test_vision_api.py b/tools/server/tests/unit/test_vision_api.py new file mode 100644 index 0000000000000..7cc4096f19e0c --- /dev/null +++ b/tools/server/tests/unit/test_vision_api.py @@ -0,0 +1,59 @@ +import pytest +from utils import * +import base64 +import requests + +server: ServerProcess + +IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" +IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" + +response = requests.get(IMG_URL_0) +response.raise_for_status() # Raise an exception for bad status codes +IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinygemma3() + + +@pytest.mark.parametrize( + "prompt, image_url, success, re_content", + [ + # test model is trained on CIFAR-10, but it's quite dumb due to small size + ("What is this:\n", IMG_URL_0, True, "(cat)+"), + ("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log + ("What is this:\n", IMG_URL_1, True, "(frog)+"), + ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache + ("What is this:\n", "malformed", False, None), + ("What is this:\n", "https://google.com/404", False, None), # non-existent image + ("What is this:\n", "https://ggml.ai", False, None), # non-image data + ] +) +def test_vision_chat_completion(prompt, image_url, success, re_content): + global server + server.start(timeout_seconds=60) # vision model may take longer to load due to download size + if image_url == "IMG_BASE64_0": + image_url = IMG_BASE64_0 + res = server.make_request("POST", "/chat/completions", data={ + "temperature": 0.0, + "top_k": 1, + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": { + "url": image_url, + }}, + ]}, + ], + }) + if success: + assert res.status_code == 200 + choice = res.body["choices"][0] + assert "assistant" == choice["message"]["role"] + assert match_regex(re_content, choice["message"]["content"]) + else: + assert res.status_code != 200 + diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 4dc2062a8e5b9..27a0f0356aae1 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -88,6 +88,7 @@ class ServerProcess: chat_template: str | None = None chat_template_file: str | None = None server_path: str | None = None + mmproj_url: str | None = None # session variables process: subprocess.Popen | None = None @@ -194,6 +195,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file: server_args.extend(["--chat-template-file", self.chat_template_file]) + if self.mmproj_url: + server_args.extend(["--mmproj-url", self.mmproj_url]) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") @@ -379,6 +382,21 @@ def jina_reranker_tiny() -> ServerProcess: server.server_reranking = True return server + @staticmethod + def tinygemma3() -> ServerProcess: + server = ServerProcess() + # mmproj is already provided by HF registry API + server.model_hf_repo = "ggml-org/tinygemma3-GGUF" + server.model_hf_file = "tinygemma3-Q8_0.gguf" + server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf" + server.model_alias = "tinygemma3" + server.n_ctx = 1024 + server.n_batch = 32 + server.n_slots = 2 + server.n_predict = 4 + server.seed = 42 + return server + def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: """ diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index b497959fd8689..23163f4fe939e 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -3,7 +3,9 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "arg.h" // common_remote_get_content #include "base64.hpp" +#include "mtmd.h" // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 @@ -21,6 +23,7 @@ #include #include #include +#include #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" @@ -41,6 +44,8 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +using raw_buffer = std::vector; + template static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value @@ -386,7 +391,7 @@ static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string & encoded_string) { +static inline raw_buffer base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -396,7 +401,7 @@ static inline std::vector base64_decode(const std::string & encoded_str uint8_t char_array_4[4]; uint8_t char_array_3[3]; - std::vector ret; + raw_buffer ret; while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; @@ -579,7 +584,9 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const struct common_chat_templates * tmpls) + const struct common_chat_templates * tmpls, + bool allow_non_text, + std::vector & out_files) { json llama_params; @@ -627,8 +634,77 @@ static json oaicompat_completion_params_parse( } } + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for (auto & msg : messages) { + json & content = msg.at("content"); + if (content.is_string()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + json image_url = json_value(p, "image_url", json::object()); + if (type == "image_url") { + if (!allow_non_text) { + throw std::runtime_error("image input is not supported by this server"); + } + + std::string url = json_value(image_url, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = MTMD_DEFAULT_IMAGE_MARKER; + p.erase("image_url"); + } + } + } + common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); @@ -935,3 +1011,286 @@ static std::vector parse_lora_request( return lora; } + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** position in tokens to the image chunk + std::unordered_map map_pos_to_image; + + // list of tokens + // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token + // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** + // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + llama_tokens tokens; + + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // pos 0 1 2 3 4 5 6 7 8 9 + // map_pos_to_image will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } + } + + server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + + // for debugging + std::string str() const { + std::ostringstream oss; + oss << "tokens: "; + for (const auto & t : tokens) { + if (t == LLAMA_TOKEN_NULL) { + oss << " "; + } else { + oss << t << " "; + } + } + oss << "\n"; + oss << "image pos: "; + for (const auto & it : map_pos_to_image) { + oss << it.first << ", "; + } + return oss.str(); + } + + const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { + auto it = map_pos_to_image.find(pos); + if (it != map_pos_to_image.end()) { + return it->second; + } else { + throw std::runtime_error("Chunk not found"); + } + } + + void push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); + } + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + GGML_ASSERT(has_mtmd); + auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk); + const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens); + llama_pos start_pos = tokens.size(); + for (int i = 0; i < n_pos; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_pos_to_image[start_pos] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } + } + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); + } + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; + } + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; + } + + size_t size() const { + return tokens.size(); + } + + bool empty() const { + return tokens.empty(); + } + + void clear() { + tokens.clear(); + } + + void resize(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + llama_token last_token = tokens[n - 1]; + // make sure we never remove tokens in the middle of an image + if (last_token == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) { + llama_pos pos = it->first; + if (pos >= (llama_pos)n) { + it = map_pos_to_image.erase(it); + } else { + ++it; + } + } + } + tokens.resize(n); + } + + std::string detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); + } + + size_t get_common_prefix(const server_tokens & b) const { + size_t max_idx = std::min(tokens.size(), b.tokens.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto & ai = tokens[i]; + auto & bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + GGML_ASSERT(has_mtmd); + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); + const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get()); + const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get()); + std::string ai_id = mtmd_image_tokens_get_id(a_img); + std::string bi_id = mtmd_image_tokens_get_id(b_img); + size_t a_pos = mtmd_image_tokens_get_n_pos(a_img); + size_t b_pos = mtmd_image_tokens_get_n_pos(b_img); + if (ai_id == bi_id && a_pos == b_pos) { + GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen + i += a_pos - 1; // will be +1 by the for loop + continue; + } else { + return i; + } + } else if (ai == bi) { + continue; + } else { + return i; + } + } + return max_idx; // all tokens are equal + } + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get()); + size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens); + i += n_pos - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; + } + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + llama_pos n_past, + int32_t seq_id, + llama_pos & n_pos_out) { + auto it = map_pos_to_image.find(n_past); + if (it == map_pos_to_image.end()) { + throw std::runtime_error("Chunk not found"); + } + SRV_INF("%s\n", "processing image..."); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past = n_past; + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + it->second.get(), // chunk + n_past, + seq_id, + n_batch, + true, // logits last + &new_n_past); + SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_pos_out = n_past; + return result; + } + n_pos_out = new_n_past; + return 0; + } +}; + +// Computes FNV-1a hash of the data +static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); +}