Compartir a través de


Memoria del agente de IA

Importante

Esta característica está en versión preliminar pública.

La memoria permite a los agentes de inteligencia artificial recordar información anterior de la conversación o de conversaciones anteriores. Esto permite a los agentes proporcionar respuestas basadas en contexto y crear experiencias personalizadas a lo largo del tiempo. Use Databricks Lakebase, una base de datos OLTP de Postgres totalmente administrada, para administrar el estado y el historial de la conversación.

Requisitos

Memoria a corto plazo frente a memoria a largo plazo

La memoria a corto plazo captura el contexto en una sola sesión de conversación mientras que la memoria a largo plazo extrae y almacena información clave en varias conversaciones. Puede compilar el agente con uno o ambos tipos de memoria.

Agentes con memoria a corto y largo plazo

Memoria a corto plazo Memoria a largo plazo
Captura de contexto en una sola sesión de conversación mediante identificadores de subprocesos y puntos de control
Mantener el contexto de las preguntas de seguimiento dentro de una sesión
Depuración y prueba de flujos de conversación mediante viajes de tiempo
Extracción y almacenamiento automático de información clave en varias sesiones
Personalización de interacciones basadas en preferencias anteriores
Creación de una base de conocimiento sobre los usuarios que mejoran las respuestas a lo largo del tiempo

Ejemplos de cuadernos

Agente con memoria a corto plazo

Obtener el cuaderno

Agente con memoria a largo plazo

Obtener el cuaderno

Consulta a tu agente implementado

Después de implementar su agente en un endpoint de servicio de modelos, consulte Consultar un agente de Mosaic AI implementado para obtener instrucciones de consulta.

Para pasar un identificador de subproceso, use el parámetro extra_body. En el ejemplo siguiente se muestra cómo introducir un ID de subproceso en un ResponsesAgent punto de conexión.

   response1 = client.responses.create(
    model=endpoint,
    input=[{"role": "user", "content": "What are stateful agents?"}],
    extra_body={
        "custom_inputs": {"thread_id": thread_id}
    }
)

Si utiliza un cliente que pasa automáticamente ChatContext como Playground o Review, el identificador de conversación y el identificador de usuario se pasarán automáticamente para casos de uso de memoria a corto o largo plazo.

Viaje de tiempo de memoria a corto plazo

En el caso de los agentes con memoria a corto plazo, use LangGraph time-travel para reanudar la ejecución desde los puntos de control. Puede reproducir la conversación o modificarla para explorar rutas de acceso alternativas. Cada vez que se reanuda desde un punto de control, LangGraph crea una nueva bifurcación en el historial de conversaciones, conservando el original al tiempo que habilita la experimentación.

  1. En el código del agente, cree funciones que recuperen el historial de puntos de comprobación y actualicen el estado del punto de control en la LangGraphResponsesAgent clase :

    from typing import List, Dict
    def get_checkpoint_history(self, thread_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Retrieve checkpoint history for a thread.
    
        Args:
            thread_id: The thread identifier
            limit: Maximum number of checkpoints to return
    
        Returns:
            List of checkpoint information including checkpoint_id, timestamp, and next nodes
        """
        config = {"configurable": {"thread_id": thread_id}}
    
       with CheckpointSaver(instance_name=LAKEBASE_INSTANCE_NAME) as checkpointer:
            graph = self._create_graph(checkpointer)
    
            history = []
            for state in graph.get_state_history(config):
                if len(history) >= limit:
                    break
    
                history.append({
                    "checkpoint_id": state.config["configurable"]["checkpoint_id"],
                    "thread_id": thread_id,
                    "timestamp": state.created_at,
                    "next_nodes": state.next,
                    "message_count": len(state.values.get("messages", [])),
                    # Include last message summary for context
                    "last_message": self._get_last_message_summary(state.values.get("messages", []))
                })
    
            return history
    
    def _get_last_message_summary(self, messages: List[Any]) -> Optional[str]:
        """Get a snippet of the last message for checkpoint identification"""
        return getattr(messages[-1], "content", "")[:100] if messages else None
    
    def update_checkpoint_state(self, thread_id: str, checkpoint_id: str,
                            new_messages: Optional[List[Dict]] = None) -> Dict[str, Any]:
        """Update state at a specific checkpoint (used for modifying conversation history).
    
        Args:
            thread_id: The thread identifier
            checkpoint_id: The checkpoint to update
            new_messages: Optional new messages to set at this checkpoint
    
        Returns:
            New checkpoint configuration including the new checkpoint_id
        """
        config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_id": checkpoint_id
            }
        }
    
        with CheckpointSaver(instance_name=LAKEBASE_INSTANCE_NAME) as checkpointer:
            graph = self._create_graph(checkpointer)
    
            # Prepare the values to update
            values = {}
            if new_messages:
                cc_msgs = self.prep_msgs_for_cc_llm(new_messages)
                values["messages"] = cc_msgs
    
            # Update the state (creates a new checkpoint)
            new_config = graph.update_state(config, values=values)
    
            return {
                "thread_id": thread_id,
                "checkpoint_id": new_config["configurable"]["checkpoint_id"],
                "parent_checkpoint_id": checkpoint_id
            }
    
  2. Actualice las funciones predict y predict_stream para admitir pasar puntos de control.

    Predict

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Non-streaming prediction"""
        # The same thread_id is used by BOTH predict() and predict_stream()
        ci = dict(request.custom_inputs or {})
        if "thread_id" not in ci:
            ci["thread_id"] = str(uuid.uuid4())
        request.custom_inputs = ci
    
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
    
        # Include thread_id and checkpoint_id in custom outputs
        custom_outputs = {
            "thread_id": ci["thread_id"]
        }
        if "checkpoint_id" in ci:
            custom_outputs["parent_checkpoint_id"] = ci["checkpoint_id"]
    
        try:
            history = self.get_checkpoint_history(ci["thread_id"], limit=1)
            if history:
                custom_outputs["checkpoint_id"] = history[0]["checkpoint_id"]
        except Exception as e:
            logger.warning(f"Could not retrieve new checkpoint_id: {e}")
    
        return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs)
    

    Predict_stream

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Streaming prediction with PostgreSQL checkpoint branching support.
    
        Accepts in custom_inputs:
        - thread_id: Conversation thread identifier for session
        - checkpoint_id (optional): Checkpoint to resume from (for branching)
        """
        # Get thread ID and checkpoint ID from custom inputs
        custom_inputs = request.custom_inputs or {}
        thread_id = custom_inputs.get("thread_id", str(uuid.uuid4()))  # generate new thread ID if one is not passed in
        checkpoint_id = custom_inputs.get("checkpoint_id")  # Optional for branching
    
        # Convert incoming Responses messages to LangChain format
        langchain_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])
    
        # Build checkpoint configuration
        checkpoint_config = {"configurable": {"thread_id": thread_id}}
        # If checkpoint_id is provided, we're branching from that checkpoint
        if checkpoint_id:
            checkpoint_config["configurable"]["checkpoint_id"] = checkpoint_id
            logger.info(f"Branching from checkpoint: {checkpoint_id} in thread: {thread_id}")
    
        # DATABASE CONNECTION POOLING LOGIC FOLLOWS
        # Use connection from pool
    

A continuación, pruebe la ramificación del punto de control:

  1. Inicie un subproceso conversacional y agregue algunos mensajes:

    from agent import AGENT
    # Initial conversation - starts a new thread
    response1 = AGENT.predict({
        "input": [{"role": "user", "content": "I'm planning for an upcoming trip!"}],
    })
    print(response1.model_dump(exclude_none=True))
    thread_id = response1.custom_outputs["thread_id"]
    
    # Within the same thread, ask a follow-up question - short-term memory will remember previous messages in the same thread/conversation session
    response2 = AGENT.predict({
        "input": [{"role": "user", "content": "I'm headed to SF!"}],
        "custom_inputs": {"thread_id": thread_id}
    })
    print(response2.model_dump(exclude_none=True))
    
    # Within the same thread, ask a follow-up question - short-term memory will remember previous messages in the same thread/conversation session
    response3 = AGENT.predict({
        "input": [{"role": "user", "content": "Where did I say I'm going?"}],
        "custom_inputs": {"thread_id": thread_id}
    })
    print(response3.model_dump(exclude_none=True))
    
    
  2. Recupere el historial de puntos de control y bifurque la conversación con un mensaje diferente:

    # Get checkpoint history to find branching point
    history = AGENT.get_checkpoint_history(thread_id, 20)
    # Retrieve checkpoint at index - indices count backward from most recent checkpoint
    index = max(1, len(history) - 4)
    branch_checkpoint = history[index]["checkpoint_id"]
    
    # Branch from node with next_node = `('__start__',)` to re-input message to agent at certain part of conversation
    # I want to update the information of which city I am going to
    # Within the same thread, branch from a checkpoint and override it with different context to continue the conversation in a new fork
    response4 = AGENT.predict({
        "input": [{"role": "user", "content": "I'm headed to New York!"}],
        "custom_inputs": {
            "thread_id": thread_id,
            "checkpoint_id": branch_checkpoint # Branch from this checkpoint!
        }
    })
    print(response4.model_dump(exclude_none=True))
    
    # Thread ID stays the same even though it branched from a checkpoint:
    branched_thread_id = response4.custom_outputs["thread_id"]
    print(f"original thread id was {thread_id}")
    print(f"new thread id after branching is the same as original: {branched_thread_id}")
    
    # Continue the conversation in the same thread and it will pick up from the information you tell it in your branch
    response5 = AGENT.predict({
        "input": [{"role": "user", "content": "Where am I going?"}],
        "custom_inputs": {
            "thread_id": thread_id,
        }
    })
    print(response5.model_dump(exclude_none=True))
    

Pasos siguientes