fix(ollama): pre-imported funcs instrumentation failure #2871

Merged
merged 12 commits on May 8, 2025
@@ -48,6 +48,29 @@
]


def _sanitize_copy_messages(wrapped, instance, args, kwargs):
# original signature: _copy_messages(messages)
messages = args[0] if args else []
sanitized = []
for msg in messages or []:
if isinstance(msg, dict):
msg_copy = dict(msg)
tc_list = msg_copy.get("tool_calls")
if tc_list:
for tc in tc_list:
func = tc.get("function")
arg = func.get("arguments") if func else None
if isinstance(arg, str):
try:
func["arguments"] = json.loads(arg)
except Exception:
pass
sanitized.append(msg_copy)
else:
sanitized.append(msg)
return wrapped(sanitized)


def should_send_prompts():
return (
os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
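For illustration only (not part of this diff, and assuming the message shape used by ollama tool calls), a minimal sketch of what the `_sanitize_copy_messages` wrapper above normalizes: a `tool_calls` entry whose `function.arguments` arrives as a JSON string is decoded into a dict before `_copy_messages` hands it to Pydantic validation.

# Hypothetical message with stringified tool-call arguments (the case this guards against).
message = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
    ],
}

# Stand-in for ollama._client._copy_messages, used only in this sketch.
def fake_copy_messages(messages):
    return messages

sanitized = _sanitize_copy_messages(fake_copy_messages, None, ([message],), {})
# The JSON string has been decoded on the copied message:
assert sanitized[0]["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}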
@@ -89,15 +112,18 @@ def _set_prompts(span, messages):
f"{prefix}.tool_calls.{i}.name",
function.get("name"),
)
# record arguments: ensure it's a JSON string for span attributes
raw_args = function.get("arguments")
if isinstance(raw_args, dict):
arg_str = json.dumps(raw_args)
else:
arg_str = raw_args
_set_span_attribute(
span,
f"{prefix}.tool_calls.{i}.arguments",
function.get("arguments"),
arg_str,
)

if function.get("arguments"):
function["arguments"] = json.loads(function.get("arguments"))


def set_tools_attributes(span, tools):
if not tools:
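A note on the arguments handling above: OpenTelemetry span attribute values must be primitive types (strings, numbers, booleans, or sequences of those), so when `arguments` is already a dict it is serialized before being recorded. A tiny sketch of that assumption:

import json

raw_args = {"city": "Paris"}                     # arguments already parsed into a dict
arg_str = json.dumps(raw_args) if isinstance(raw_args, dict) else raw_args
# arg_str == '{"city": "Paris"}' -- a plain string, safe to set as a span attribute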
@@ -118,15 +144,15 @@ def set_tools_attributes(span, tools):

@dont_throw
def _set_input_attributes(span, llm_request_type, kwargs):
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, kwargs.get("model"))
json_data = kwargs.get("json", {})
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, json_data.get("model"))
_set_span_attribute(
span, SpanAttributes.LLM_IS_STREAMING, kwargs.get("stream") or False
)

if should_send_prompts():
if llm_request_type == LLMRequestTypeValues.CHAT:
_set_span_attribute(span, f"{SpanAttributes.LLM_PROMPTS}.0.role", "user")
for index, message in enumerate(kwargs.get("messages")):
for index, message in enumerate(json_data.get("messages")):
_set_span_attribute(
span,
f"{SpanAttributes.LLM_PROMPTS}.{index}.content",
@@ -137,13 +163,13 @@ def _set_input_attributes(span, llm_request_type, kwargs):
f"{SpanAttributes.LLM_PROMPTS}.{index}.role",
message.get("role"),
)
_set_prompts(span, kwargs.get("messages"))
if kwargs.get("tools"):
set_tools_attributes(span, kwargs.get("tools"))
_set_prompts(span, json_data.get("messages"))
if json_data.get("tools"):
set_tools_attributes(span, json_data.get("tools"))
else:
_set_span_attribute(span, f"{SpanAttributes.LLM_PROMPTS}.0.role", "user")
_set_span_attribute(
span, f"{SpanAttributes.LLM_PROMPTS}.0.content", kwargs.get("prompt")
span, f"{SpanAttributes.LLM_PROMPTS}.0.content", json_data.get("prompt")
)


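Why `json_data = kwargs.get("json", {})`: once instrumentation hooks `Client._request` (see `_instrument` below), the chat/generate payload no longer arrives as top-level kwargs but under the `json` keyword of the underlying HTTP request call. A hedged sketch of the shape this assumes:

# Hypothetical call as seen by the Client._request wrapper when a user runs
# chat(model="llama3", messages=[...]); the positional order (response class,
# HTTP method, path) is an assumption about ollama's internals.
args = (object, "POST", "/api/chat")
kwargs = {
    "json": {
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": False,
    }
}
json_data = kwargs.get("json", {})
assert json_data.get("model") == "llama3"
assert json_data.get("messages")[0]["role"] == "user"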
@@ -240,7 +266,8 @@ def _accumulate_streaming_response(span, token_histogram, llm_request_type, resp
accumulated_response["message"]["content"] += res["message"]["content"]
accumulated_response["message"]["role"] = res["message"]["role"]
elif llm_request_type == LLMRequestTypeValues.COMPLETION:
accumulated_response["response"] += res["response"]
text = res.get("response", "")
accumulated_response["response"] += text

response_data = res.model_dump() if hasattr(res, 'model_dump') else res
_set_response_attributes(span, token_histogram, llm_request_type, response_data | accumulated_response)
@@ -260,7 +287,8 @@ async def _aaccumulate_streaming_response(span, token_histogram, llm_request_typ
accumulated_response["message"]["content"] += res["message"]["content"]
accumulated_response["message"]["role"] = res["message"]["role"]
elif llm_request_type == LLMRequestTypeValues.COMPLETION:
accumulated_response["response"] += res["response"]
text = res.get("response", "")
accumulated_response["response"] += text

response_data = res.model_dump() if hasattr(res, 'model_dump') else res
_set_response_attributes(span, token_histogram, llm_request_type, response_data | accumulated_response)
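Two small behaviours worth calling out in both accumulators: `res.get("response", "")` tolerates chunks that omit the text field, and the final `response_data | accumulated_response` merge lets the accumulated values win over the last chunk's fields. A quick illustration of the dict-merge semantics:

response_data = {"model": "llama3", "response": "!", "done": True}   # last streamed chunk
accumulated_response = {"response": "Hello world!"}                  # built up across chunks
merged = response_data | accumulated_response
# merged == {"model": "llama3", "response": "Hello world!", "done": True}
# (keys on the right-hand side of `|` take precedence)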
@@ -336,13 +364,11 @@ def _wrap(
if response:
if duration_histogram:
duration = end_time - start_time
duration_histogram.record(
duration,
attributes={
SpanAttributes.LLM_SYSTEM: "Ollama",
SpanAttributes.LLM_RESPONSE_MODEL: kwargs.get("model"),
},
)
attrs = {SpanAttributes.LLM_SYSTEM: "Ollama"}
model = kwargs.get("model")
if model is not None:
attrs[SpanAttributes.LLM_RESPONSE_MODEL] = model
duration_histogram.record(duration, attributes=attrs)

if span.is_recording():
if kwargs.get("stream"):
@@ -392,13 +418,11 @@ async def _awrap(
if response:
if duration_histogram:
duration = end_time - start_time
duration_histogram.record(
duration,
attributes={
SpanAttributes.LLM_SYSTEM: "Ollama",
SpanAttributes.LLM_RESPONSE_MODEL: kwargs.get("model"),
},
)
attrs = {SpanAttributes.LLM_SYSTEM: "Ollama"}
model = kwargs.get("model")
if model is not None:
attrs[SpanAttributes.LLM_RESPONSE_MODEL] = model
duration_histogram.record(duration, attributes=attrs)

if span.is_recording():
if kwargs.get("stream"):
@@ -459,23 +483,23 @@ def _instrument(self, **kwargs):
duration_histogram,
) = (None, None)

for wrapped_method in WRAPPED_METHODS:
wrap_method = wrapped_method.get("method")
wrap_function_wrapper(
"ollama._client",
f"Client.{wrap_method}",
_wrap(tracer, token_histogram, duration_histogram, wrapped_method),
)
wrap_function_wrapper(
"ollama._client",
f"AsyncClient.{wrap_method}",
_awrap(tracer, token_histogram, duration_histogram, wrapped_method),
)
wrap_function_wrapper(
"ollama",
f"{wrap_method}",
_wrap(tracer, token_histogram, duration_histogram, wrapped_method),
)
# Patch _copy_messages to sanitize tool_calls arguments before Pydantic validation
wrap_function_wrapper(
"ollama._client",
"_copy_messages",
_sanitize_copy_messages,
)
# instrument all llm methods (generate/chat/embeddings) via _request dispatch wrapper
wrap_function_wrapper(
"ollama._client",
"Client._request",
_dispatch_wrap(tracer, token_histogram, duration_histogram),
)
wrap_function_wrapper(
"ollama._client",
"AsyncClient._request",
_dispatch_awrap(tracer, token_histogram, duration_histogram),
)

def _uninstrument(self, **kwargs):
for wrapped_method in WRAPPED_METHODS:
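Background on why the PR moves from per-method patches to a `Client._request` hook (the "pre-imported funcs" failure in the title): a function bound with `from ollama import chat` before instrumentation runs keeps pointing at the original object, so later patching of the module attribute never reaches it, while every call path still funnels through `Client._request`. A generic, ollama-free illustration of that Python behaviour:

import math
from math import sqrt        # the caller binds the original function object

math.sqrt = lambda x: -1.0   # monkey-patching the module attribute afterwards...
print(sqrt(4))               # ...prints 2.0: the pre-imported name is unaffected
print(math.sqrt(4))          # only lookups through the module see the patch (-1.0)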
@@ -491,3 +515,33 @@ def _uninstrument(self, **kwargs):
"ollama",
wrapped_method.get("method"),
)


def _dispatch_wrap(tracer, token_histogram, duration_histogram):
def wrapper(wrapped, instance, args, kwargs):
to_wrap = None
if len(args) > 2 and isinstance(args[2], str):
path = args[2]
op = path.rstrip('/').split('/')[-1]
to_wrap = next((m for m in WRAPPED_METHODS if m.get("method") == op), None)
if to_wrap:
return _wrap(tracer, token_histogram, duration_histogram, to_wrap)(
wrapped, instance, args, kwargs
)
return wrapped(*args, **kwargs)
return wrapper


def _dispatch_awrap(tracer, token_histogram, duration_histogram):
async def wrapper(wrapped, instance, args, kwargs):
to_wrap = None
if len(args) > 2 and isinstance(args[2], str):
path = args[2]
op = path.rstrip('/').split('/')[-1]
to_wrap = next((m for m in WRAPPED_METHODS if m.get("method") == op), None)
if to_wrap:
return await _awrap(tracer, token_histogram, duration_histogram, to_wrap)(
wrapped, instance, args, kwargs
)
return await wrapped(*args, **kwargs)
return wrapper
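A quick sketch of the path-to-operation dispatch both wrappers above rely on (the `/api/...` paths are an assumption about ollama's REST endpoints, and `_METHODS` below is a stripped-down stand-in for the module's `WRAPPED_METHODS`):

_METHODS = [{"method": "chat"}, {"method": "generate"}, {"method": "embeddings"}]

def op_for_path(path):
    op = path.rstrip('/').split('/')[-1]
    return next((m for m in _METHODS if m.get("method") == op), None)

assert op_for_path("/api/chat") == {"method": "chat"}
assert op_for_path("/api/embeddings/") == {"method": "embeddings"}
assert op_for_path("/api/show") is None   # non-LLM requests fall through uninstrumented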
13 changes: 9 additions & 4 deletions packages/opentelemetry-instrumentation-ollama/poetry.lock

Some generated files are not rendered by default.

12 changes: 6 additions & 6 deletions packages/opentelemetry-instrumentation-ollama/tests/test_chat.py
@@ -1,5 +1,5 @@
import pytest
import ollama
from ollama import AsyncClient, chat
from opentelemetry.semconv_ai import SpanAttributes
from unittest.mock import MagicMock
from opentelemetry.instrumentation.ollama import _set_response_attributes
@@ -8,7 +8,7 @@

@pytest.mark.vcr
def test_ollama_chat(exporter):
response = ollama.chat(
response = chat(
model="llama3",
messages=[
{
@@ -45,7 +45,7 @@ def test_ollama_chat(exporter):

@pytest.mark.vcr
def test_ollama_chat_tool_calls(exporter):
ollama.chat(
chat(
model="llama3.1",
messages=[
{
@@ -93,7 +93,7 @@ def test_ollama_chat_tool_calls(exporter):

@pytest.mark.vcr
def test_ollama_streaming_chat(exporter):
gen = ollama.chat(
gen = chat(
model="llama3",
messages=[
{
@@ -136,7 +136,7 @@ def test_ollama_streaming_chat(exporter):
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_ollama_async_chat(exporter):
client = ollama.AsyncClient()
client = AsyncClient()
response = await client.chat(
model="llama3",
messages=[
@@ -176,7 +176,7 @@ async def test_ollama_async_chat(exporter):
@pytest.mark.vcr
@pytest.mark.asyncio
async def test_ollama_async_streaming_chat(exporter):
client = ollama.AsyncClient()
client = AsyncClient()
gen = await client.chat(
model="llama3",
messages=[