Function decorator

The @mlflow.trace decorator lets you create a span for any function. Decorating a function is the simplest way to add tracing, requiring only minimal code changes.

  • MLflow detects parent-child relationships between functions, making it compatible with automatic tracing integrations.
  • Exceptions raised during function execution are captured and recorded as span events.
  • The function's name, inputs, outputs, and execution time are recorded automatically.
  • It can be used alongside automatic tracing.

Prerequisites

MLflow 3

This page requires the following packages:

  • mlflow[databricks] 3.1 or later: Core MLflow functionality with GenAI features and Databricks connectivity.
  • openai 1.0.0 or later: (Optional) Only required if your custom code interacts with OpenAI; substitute the relevant SDK if needed.

Install the basic requirements:

%pip install --upgrade "mlflow[databricks]>=3.1"
# %pip install --upgrade openai>=1.0.0 # Install if needed

MLflow 2.x

This guide requires the following packages:

  • mlflow[databricks] 2.15.0 or later: Core MLflow functionality with Databricks connectivity.
  • openai 1.0.0 or later: (Optional) Only required if your custom code interacts with OpenAI.

Note

Databricks strongly recommends installing MLflow 3.1 or later when using mlflow[databricks].

Install the basic requirements:

%pip install --upgrade "mlflow[databricks]>=2.15.0,<3.0.0"
# pip install --upgrade openai>=1.0.0 # Install if needed

Basic example

The following code is a minimal example of tracing Python functions with the decorator.

Tip

To ensure full observability, the @mlflow.trace decorator should generally be the outermost decorator when multiple decorators are applied. For a detailed explanation and examples, see Using @mlflow.trace with Other Decorators.

import mlflow


@mlflow.trace(span_type="func", attributes={"key": "value"})
def add_1(x):
    return x + 1


@mlflow.trace(span_type="func", attributes={"key1": "value1"})
def minus_1(x):
    return x - 1


@mlflow.trace(name="Trace Test")
def trace_test(x):
    step1 = add_1(x)
    return minus_1(step1)


trace_test(4)

Trace decorator

Note

When a trace contains multiple spans with the same name, MLflow appends an auto-incrementing suffix to them, such as _1, _2.
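
For example, in the following minimal sketch (hypothetical functions, not part of the example above), both inner calls would otherwise produce spans with the same name:

import mlflow


@mlflow.trace
def step(x):
    return x + 1


@mlflow.trace
def pipeline(x):
    # Both calls create a span named "step"; MLflow disambiguates the
    # duplicates with auto-incrementing suffixes such as _1, _2.
    return step(step(x))


pipeline(1)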

Customize spans

The @mlflow.trace decorator accepts the following arguments to customize the span it creates:

  • The name argument overrides the default span name (the name of the decorated function).
  • The span_type argument sets the span type. Set one of the built-in span types or a string.
  • The attributes argument adds custom attributes to the span.

Tip

When combining @mlflow.trace with other decorators (for example, from a web framework), it is critical that it is the outermost one. For explicit examples of correct versus incorrect ordering, see Using @mlflow.trace with Other Decorators.

import mlflow
from mlflow.entities import SpanType

# `client` is assumed to be a pre-configured LLM client object
@mlflow.trace(
    name="call-local-llm", span_type=SpanType.LLM, attributes={"model": "gpt-4o-mini"}
)
def invoke(prompt: str):
    return client.invoke(
        messages=[{"role": "user", "content": prompt}], model="gpt-4o-mini"
    )

Alternatively, you can update the span dynamically inside the function by using the mlflow.get_current_active_span API.

import mlflow
from mlflow.entities import SpanType


@mlflow.trace(span_type=SpanType.LLM)
def invoke(prompt: str):
    model_id = "gpt-4o-mini"
    # Get the current span (created by the @mlflow.trace decorator)
    span = mlflow.get_current_active_span()
    # Set the attribute on the span
    span.set_attributes({"model": model_id})
    return client.invoke(messages=[{"role": "user", "content": prompt}], model=model_id)

Using @mlflow.trace with other decorators

When applying multiple decorators to a single function, make sure @mlflow.trace is the outermost decorator (the one at the very top). This ensures that MLflow can capture the entire execution of the function, including the behavior of any inner decorators.

If @mlflow.trace is not the outermost decorator, its visibility into the function's execution may be limited or incorrect, resulting in incomplete traces or a misleading representation of the function's inputs, outputs, and execution time.

Consider the following conceptual example:

import mlflow
import functools
import time

# A hypothetical additional decorator
def simple_timing_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds by simple_timing_decorator.")
        return result
    return wrapper

# Correct order: @mlflow.trace is outermost
@mlflow.trace(name="my_decorated_function_correct_order")
@simple_timing_decorator
# @another_framework_decorator # e.g., @app.route("/mypath") from Flask
def my_complex_function(x, y):
    # Function logic here
    time.sleep(0.1) # Simulate work
    return x + y

# Incorrect order: @mlflow.trace is NOT outermost
@simple_timing_decorator
@mlflow.trace(name="my_decorated_function_incorrect_order")
# @another_framework_decorator
def my_other_complex_function(x, y):
    time.sleep(0.1)
    return x * y

# Example calls
if __name__ == "__main__":
    print("Calling function with correct decorator order:")
    my_complex_function(5, 3)

    print("\nCalling function with incorrect decorator order:")
    my_other_complex_function(5, 3)

In the my_complex_function example (correct order), @mlflow.trace captures the full execution, including the timing added by simple_timing_decorator. In my_other_complex_function (incorrect order), the trace captured by MLflow might not accurately reflect the total execution time, or might miss modifications to the inputs and outputs made by simple_timing_decorator before @mlflow.trace sees them.

Adding tags to traces

You can add tags to traces to provide additional metadata at the trace level. There are several different ways to set tags on a trace; for the other approaches, see the guide on attaching custom tags.

@mlflow.trace
def my_func(x):
    mlflow.update_current_trace(tags={"fruit": "apple"})
    return x + 1

Customizing request and response previews in the UI

The Traces tab in the MLflow UI displays a list of traces, and the Request and Response columns show a preview of each trace's end-to-end inputs and outputs. This lets you quickly understand what each trace represents.

By default, these previews are truncated to a fixed number of characters. You can customize what is shown in these columns with the request_preview and response_preview parameters of the mlflow.update_current_trace() function. This is particularly useful for complex inputs or outputs where the default truncation might not surface the most relevant information.

Below is an example of setting a custom request preview for a trace that processes a long document and user instructions, so that the most relevant information is shown in the Request column of the UI:

import mlflow

@mlflow.trace(name="Summarization Pipeline")
def summarize_document(document_content: str, user_instructions: str):
    # Construct a custom preview for the request column
    # For example, show beginning of document and user instructions
    request_p = f"Doc: {document_content[:30]}... Instr: {user_instructions[:30]}..."
    mlflow.update_current_trace(request_preview=request_p)

    # Simulate LLM call
    # messages = [
    #     {"role": "system", "content": "Summarize the following document based on user instructions."},
    #     {"role": "user", "content": f"Document: {document_content}\nInstructions: {user_instructions}"}
    # ]
    # completion = client.chat.completions.create(model="gpt-4o-mini", messages=messages)
    # summary = completion.choices[0].message.content
    summary = f"Summary of document starting with '{document_content[:20]}...' based on '{user_instructions}'"

    # Customize the response preview
    response_p = f"Summary: {summary[:50]}..."
    mlflow.update_current_trace(response_preview=response_p)

    return summary

# Example Call
long_document = "This is a very long document that contains many details about various topics..." * 10
instructions = "Focus on the key takeaways regarding topic X."
summary_result = summarize_document(long_document, instructions)
# print(summary_result)

By setting request_preview and response_preview on the trace (typically the root span), you control how the overall interaction is summarized in the main trace list view, making it easy to identify and understand individual traces at a glance.

Automatic exception handling

If an Exception is raised during a trace-instrumented operation, the UI shows that the invocation was unsuccessful, and the partially captured data is available to assist debugging. In addition, details about the raised exception are included in the Events of the partially completed span, further helping you pinpoint where the problem occurred in your code.

Trace error
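
As a minimal sketch (with a hypothetical validation function), a raised exception does not prevent the trace from being logged:

import mlflow


@mlflow.trace
def validate(x):
    if x < 0:
        raise ValueError("x must be non-negative")
    return x


try:
    validate(-1)
except ValueError:
    # The trace is still logged: the span is marked as failed and the
    # exception details appear under the span's Events in the UI.
    pass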

Combining with automatic tracing

Manual tracing integrates seamlessly with MLflow's automatic tracing. See Combining manual and automatic tracing.

Complex workflow tracing

For complex workflows with multiple steps, use nested spans to capture the detailed execution flow:

@mlflow.trace(name="data_pipeline")
def process_data_pipeline(data_source: str):
    # Extract phase
    with mlflow.start_span(name="extract") as extract_span:
        raw_data = extract_from_source(data_source)
        extract_span.set_outputs({"record_count": len(raw_data)})

    # Transform phase
    with mlflow.start_span(name="transform") as transform_span:
        transformed = apply_transformations(raw_data)
        transform_span.set_outputs({"transformed_count": len(transformed)})

    # Load phase
    with mlflow.start_span(name="load") as load_span:
        result = load_to_destination(transformed)
        load_span.set_outputs({"status": "success"})

    return result

Multi-threading

MLflow Tracing is thread-safe, and traces are isolated per thread by default. However, with a few extra steps you can also create a single trace that spans multiple threads.

MLflow uses Python's built-in ContextVar mechanism to ensure thread safety, and context variables are not propagated across threads by default. You therefore need to manually copy the context from the main thread to the worker threads, as shown in the following example.

import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed
import mlflow
from mlflow.entities import SpanType
import openai

client = openai.OpenAI()

# Enable MLflow Tracing for OpenAI
mlflow.openai.autolog()


@mlflow.trace
def worker(question: str) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        temperature=0.1,
        max_tokens=100,
    )
    return response.choices[0].message.content


@mlflow.trace
def main(questions: list[str]) -> list[str]:
    results = []
    # Almost same as how you would use ThreadPoolExecutor, but two additional steps
    #  1. Copy the context in the main thread using copy_context()
    #  2. Use ctx.run() to run the worker in the copied context
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = []
        for question in questions:
            ctx = contextvars.copy_context()
            futures.append(executor.submit(ctx.run, worker, question))
        for future in as_completed(futures):
            results.append(future.result())
    return results


questions = [
    "What is the capital of France?",
    "What is the capital of Germany?",
]

main(questions)

Multi-threaded tracing

Tip

In contrast, ContextVar is copied to async tasks by default. You therefore don't need to copy the context manually when using MLflow Tracing with asyncio, which can make it an easier way to handle concurrent I/O-bound tasks in Python with MLflow Tracing.
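
As a minimal sketch (assuming MLflow >= 2.16.0 for async support and a hypothetical I/O-bound coroutine), the same fan-out pattern with asyncio needs no context copying:

import asyncio

import mlflow


@mlflow.trace
async def answer(question: str) -> str:
    await asyncio.sleep(0.1)  # Simulate an I/O-bound call
    return f"Answer to: {question}"


@mlflow.trace
async def main(questions: list[str]) -> list[str]:
    # Each task inherits the active trace context automatically,
    # so all spans are grouped under the same trace.
    return await asyncio.gather(*(answer(q) for q in questions))


asyncio.run(main(["What is the capital of France?", "What is the capital of Germany?"]))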

Streaming outputs

As of MLflow 2.20.2, the @mlflow.trace decorator can be used to trace functions that return a generator or an iterator.

@mlflow.trace
def stream_data():
    for i in range(5):
        yield i

The example above produces a trace with a single span for the stream_data function. By default, MLflow captures all elements yielded by the generator as a list in the span's output. In the example above, the span output is [0, 1, 2, 3, 4].

Note

The span for a streamed function starts when the returned iterator begins to be consumed, and ends when the iterator is exhausted or an exception is raised during iteration.
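
For example, continuing the stream_data example above:

gen = stream_data()  # No span activity yet; the span starts on first consumption
values = list(gen)   # The span ends once the iterator is exhausted (or raises)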

Using an output reducer

If you want to aggregate the elements into a single span output, use the output_reducer parameter to specify a custom function that aggregates them. The custom function takes the list of yielded elements as input.

from typing import List, Any

@mlflow.trace(output_reducer=lambda x: ",".join(x))
def stream_data():
    for c in "hello":
        yield c

In the example above, the span output is "h,e,l,l,o". The raw chunks remain available under Events in the span tab of the MLflow Trace UI, so you can inspect the individual yielded values when debugging.

Common output reducer patterns

Here are some common patterns for implementing output reducers:

Token aggregation

from typing import List, Dict, Any

def aggregate_tokens(chunks: List[str]) -> str:
    """Concatenate streaming tokens into complete text"""
    return "".join(chunks)

@mlflow.trace(output_reducer=aggregate_tokens)
def stream_text():
    for word in ["Hello", " ", "World", "!"]:
        yield word

Metrics aggregation

import time
import mlflow
from typing import List, Dict, Any

def aggregate_metrics(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate streaming metrics into summary statistics"""
    values = [c["value"] for c in chunks if "value" in c]
    return {
        "count": len(values),
        "sum": sum(values),
        "average": sum(values) / len(values) if values else 0,
        "max": max(values) if values else None,
        "min": min(values) if values else None
    }

@mlflow.trace(output_reducer=aggregate_metrics)
def stream_metrics():
    for i in range(10):
        yield {"value": i * 2, "timestamp": time.time()}

Error collection

from typing import List, Dict, Any

def collect_results_and_errors(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Separate successful results from errors"""
    results = []
    errors = []

    for chunk in chunks:
        if chunk.get("error"):
            errors.append(chunk["error"])
        else:
            results.append(chunk.get("data"))

    return {
        "results": results,
        "errors": errors,
        "success_rate": len(results) / len(chunks) if chunks else 0,
        "has_errors": len(errors) > 0
    }
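
A hypothetical usage sketch for this reducer, with a generator that mixes successful chunks and errors:

import mlflow


@mlflow.trace(output_reducer=collect_results_and_errors)
def stream_with_errors():
    yield {"data": "chunk-1"}
    yield {"error": "timeout while fetching chunk-2"}
    yield {"data": "chunk-3"}


list(stream_with_errors())  # Span output: 2 results, 1 error, success_rate ~0.67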

Advanced example: OpenAI streaming

Here is an advanced example that uses output_reducer to consolidate ChatCompletionChunk outputs from an OpenAI LLM into a single message object.

Tip

For production use cases, we recommend using automatic tracing for OpenAI, which handles this scenario automatically. The example below is for demonstration purposes.

import mlflow
import openai
from openai.types.chat import *
from typing import Optional


def aggregate_chunks(outputs: list[ChatCompletionChunk]) -> Optional[ChatCompletion]:
    """Consolidate ChatCompletionChunks to a single ChatCompletion"""
    if not outputs:
        return None

    first_chunk = outputs[0]
    delta = first_chunk.choices[0].delta
    message = ChatCompletionMessage(
        role=delta.role, content=delta.content, tool_calls=delta.tool_calls or []
    )
    finish_reason = first_chunk.choices[0].finish_reason
    for chunk in outputs[1:]:
        delta = chunk.choices[0].delta
        message.content += delta.content or ""
        message.tool_calls += delta.tool_calls or []
        finish_reason = finish_reason or chunk.choices[0].finish_reason

    base = ChatCompletion(
        id=first_chunk.id,
        choices=[Choice(index=0, message=message, finish_reason=finish_reason)],
        created=first_chunk.created,
        model=first_chunk.model,
        object="chat.completion",
    )
    return base


@mlflow.trace(output_reducer=aggregate_chunks)
def predict(messages: list[dict]):
    client = openai.OpenAI()
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        stream=True,
    )
    for chunk in stream:
        yield chunk


for chunk in predict([{"role": "user", "content": "Hello"}]):
    print(chunk)

In the example above, the resulting predict span has a single chat completion message as its output, aggregated by the custom reducer function.

Real-world use cases

Here are additional examples of output reducers for common GenAI scenarios:

LLM response with JSON parsing

import json
import mlflow
from typing import List, Dict, Any

def parse_json_from_llm(content: str) -> str:
    """Extract and clean JSON from LLM responses that may include markdown"""
    # Remove common markdown code block wrappers
    if content.startswith("```json") and content.endswith("```"):
        content = content[7:-3]  # Remove ```json prefix and ``` suffix
    elif content.startswith("```") and content.endswith("```"):
        content = content[3:-3]  # Remove generic ``` wrappers
    return content.strip()

def json_stream_reducer(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate LLM streaming output and parse JSON response"""
    full_content = ""
    metadata = {}
    errors = []

    # Process different chunk types
    for chunk in chunks:
        chunk_type = chunk.get("type", "content")

        if chunk_type == "content" or chunk_type == "token":
            full_content += chunk.get("content", "")
        elif chunk_type == "metadata":
            metadata.update(chunk.get("data", {}))
        elif chunk_type == "error":
            errors.append(chunk.get("error"))

    # Return early if errors occurred
    if errors:
        return {
            "status": "error",
            "errors": errors,
            "raw_content": full_content,
            **metadata
        }

    # Try to parse accumulated content as JSON
    try:
        cleaned_content = parse_json_from_llm(full_content)
        parsed_data = json.loads(cleaned_content)

        return {
            "status": "success",
            "data": parsed_data,
            "raw_content": full_content,
            **metadata
        }
    except json.JSONDecodeError as e:
        return {
            "status": "parse_error",
            "error": f"Failed to parse JSON: {str(e)}",
            "raw_content": full_content,
            **metadata
        }

@mlflow.trace(output_reducer=json_stream_reducer)
def generate_structured_output(prompt: str, schema: dict):
    """Generate structured JSON output from an LLM"""
    # Simulate streaming JSON generation
    yield {"type": "content", "content": '{"name": "John", '}
    yield {"type": "content", "content": '"email": "john@example.com", '}
    yield {"type": "content", "content": '"age": 30}'}

    # Add metadata
    trace_id = mlflow.get_current_active_span().request_id if mlflow.get_current_active_span() else None
    yield {"type": "metadata", "data": {"trace_id": trace_id, "model": "gpt-4"}}

Generating structured output with OpenAI

Below is a complete example of using an output reducer with OpenAI to generate and parse a structured JSON response:

import json
import mlflow
import openai
from typing import List, Dict, Any, Optional

def structured_output_reducer(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate streaming chunks into structured output with comprehensive error handling.
    Handles token streaming, metadata collection, and JSON parsing.
    """
    content_parts = []
    trace_id = None
    model_info = None
    errors = []

    for chunk in chunks:
        chunk_type = chunk.get("type", "token")

        if chunk_type == "token":
            content_parts.append(chunk.get("content", ""))
        elif chunk_type == "trace_info":
            trace_id = chunk.get("trace_id")
            model_info = chunk.get("model")
        elif chunk_type == "error":
            errors.append(chunk.get("message"))

    # Join all content parts
    full_content = "".join(content_parts)

    # Base response
    response = {
        "trace_id": trace_id,
        "model": model_info,
        "raw_content": full_content
    }

    # Handle errors
    if errors:
        response["status"] = "error"
        response["errors"] = errors
        return response

    # Try to extract and parse JSON
    try:
        # Clean markdown wrappers if present
        json_content = full_content.strip()
        if json_content.startswith("```json") and json_content.endswith("```"):
            json_content = json_content[7:-3].strip()
        elif json_content.startswith("```") and json_content.endswith("```"):
            json_content = json_content[3:-3].strip()

        parsed_data = json.loads(json_content)
        response["status"] = "success"
        response["data"] = parsed_data

    except json.JSONDecodeError as e:
        response["status"] = "parse_error"
        response["error"] = f"JSON parsing failed: {str(e)}"
        response["error_position"] = e.pos if hasattr(e, 'pos') else None

    return response

@mlflow.trace(output_reducer=structured_output_reducer)
async def generate_customer_email(
    customer_name: str,
    issue: str,
    sentiment: str = "professional"
) -> None:
    """
    Generate a structured customer service email response.
    Demonstrates real-world streaming with OpenAI and structured output parsing.
    """
    client = openai.AsyncOpenAI()

    system_prompt = """You are a customer service assistant. Generate a professional email response in JSON format:
    {
        "subject": "email subject line",
        "greeting": "personalized greeting",
        "body": "main email content addressing the issue",
        "closing": "professional closing",
        "priority": "high|medium|low"
    }"""

    user_prompt = f"Customer: {customer_name}\nIssue: {issue}\nTone: {sentiment}"

    try:
        # Stream the response
        stream = await client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            stream=True,
            temperature=0.7
        )

        # Yield streaming tokens
        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield {
                    "type": "token",
                    "content": chunk.choices[0].delta.content
                }

        # Add trace metadata
        if current_span := mlflow.get_current_active_span():
            yield {
                "type": "trace_info",
                "trace_id": current_span.request_id,
                "model": "gpt-4o-mini"
            }

    except Exception as e:
        yield {
            "type": "error",
            "message": f"OpenAI API error: {str(e)}"
        }

# Example usage
async def main():
    # This will automatically aggregate the streamed output into structured JSON
    async for chunk in generate_customer_email(
        customer_name="John Doe",
        issue="Product arrived damaged",
        sentiment="empathetic"
    ):
        # In practice, you might send these chunks to a frontend
        print(chunk.get("content", ""), end="", flush=True)
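
To run the example end to end (assuming an OpenAI API key is configured in your environment), you could drive the coroutine with asyncio:

import asyncio

asyncio.run(main())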

Note

This example demonstrates several real-world patterns:

  • Streaming UI updates: tokens can be displayed as they arrive
  • Structured output validation: JSON parsing ensures the response format
  • Error resilience: API errors and parse failures are handled gracefully
  • Trace correlation: streamed output is linked to the MLflow trace for debugging

Multi-model response aggregation

from typing import List, Dict, Any

def multi_model_reducer(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate responses from multiple models"""
    responses = {}
    latencies = {}

    for chunk in chunks:
        model = chunk.get("model")
        if model:
            responses[model] = chunk.get("response", "")
            latencies[model] = chunk.get("latency", 0)

    return {
        "responses": responses,
        "latencies": latencies,
        "fastest_model": min(latencies, key=latencies.get) if latencies else None,
        "consensus": len(set(responses.values())) == 1
    }
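
A hypothetical usage sketch for this reducer, with hard-coded chunks standing in for real model calls:

import mlflow


@mlflow.trace(output_reducer=multi_model_reducer)
def query_models(prompt: str):
    # Replace the hard-coded chunks with real calls to your model clients
    yield {"model": "model-a", "response": "Paris", "latency": 0.42}
    yield {"model": "model-b", "response": "Paris", "latency": 0.57}


query_models("What is the capital of France?")
# Span output: both responses, per-model latencies, fastest_model="model-a", consensus=True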

Testing output reducers

Output reducers can be tested independently of the tracing framework, making it easy to ensure they handle edge cases correctly:

import unittest
from typing import List, Dict, Any

def my_reducer(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Example reducer to be tested"""
    if not chunks:
        return {"status": "empty", "total": 0}

    total = sum(c.get("value", 0) for c in chunks)
    errors = [c for c in chunks if c.get("error")]

    return {
        "status": "error" if errors else "success",
        "total": total,
        "count": len(chunks),
        "average": total / len(chunks) if chunks else 0,
        "error_count": len(errors)
    }

class TestOutputReducer(unittest.TestCase):
    def test_normal_case(self):
        chunks = [
            {"value": 10},
            {"value": 20},
            {"value": 30}
        ]
        result = my_reducer(chunks)
        self.assertEqual(result["status"], "success")
        self.assertEqual(result["total"], 60)
        self.assertEqual(result["average"], 20.0)

    def test_empty_input(self):
        result = my_reducer([])
        self.assertEqual(result["status"], "empty")
        self.assertEqual(result["total"], 0)

    def test_error_handling(self):
        chunks = [
            {"value": 10},
            {"error": "Network timeout"},
            {"value": 20}
        ]
        result = my_reducer(chunks)
        self.assertEqual(result["status"], "error")
        self.assertEqual(result["total"], 30)
        self.assertEqual(result["error_count"], 1)

    def test_missing_values(self):
        chunks = [
            {"value": 10},
            {"metadata": "some info"},  # No value field
            {"value": 20}
        ]
        result = my_reducer(chunks)
        self.assertEqual(result["total"], 30)
        self.assertEqual(result["count"], 3)

Tip

  • Output reducers receive all chunks in memory at once. For very large streams, consider implementing a streaming alternative or a chunked approach.
  • The span stays open until the generator is fully consumed, which affects latency metrics.
  • Reducers should be stateless and avoid side effects for predictable behavior.

Supported function types

The @mlflow.trace decorator currently supports the following types of functions:

Function type      Supported
Sync               Yes
Async              Yes (MLflow >= 2.16.0)
Generator          Yes (MLflow >= 2.20.2)
Async generator    Yes (MLflow >= 2.20.2)
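
For instance, an async generator can be traced in the same way as a synchronous one (a minimal sketch assuming MLflow >= 2.20.2):

import mlflow


@mlflow.trace
async def stream_tokens():
    for token in ["Hello", " ", "world"]:
        yield token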

Next steps