diff --git a/docs/chat_message_history.ipynb b/docs/chat_message_history.ipynb index 8b1a4cf2..91ca31fd 100644 --- a/docs/chat_message_history.ipynb +++ b/docs/chat_message_history.ipynb @@ -1,79 +1,331 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Google DATABASE\n", - "\n", - "[Google DATABASE](https://cloud.google.com/DATABASE).\n", - "\n", - "Save chat messages into `DATABASE`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Pre-reqs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "%pip install PACKAGE_NAME" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from PACKAGE import LOADER" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Basic Usage" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "# Cloud Spanner\n", + "> [Cloud Spanner](https://cloud.google.com/spanner) is a highly scalable database that combines unlimited scalability with relational semantics, such as secondary indexes, strong consistency, schemas, and SQL providing 99.999% availability in one easy solution.\n", + "\n", + "This notebook goes over how to use `Spanner` to store chat message history with the `SpannerChatMessageHistory` class." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Before You Begin\n", + "\n", + "To run this notebook, you will need to do the following:\n", + " * [Create a Google Cloud Project](https://developers.google.com/workspace/guides/create-project)\n", + " * [Create a Spanner instance](https://cloud.google.com/spanner/docs/create-manage-instances)\n", + " * [Create a Spanner database](https://cloud.google.com/spanner/docs/create-manage-databases)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🦜🔗 Library Installation\n", + "The integration lives in its own `langchain-google-spanner` package, so we need to install it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet langchain-google-spanner" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Colab only:** Uncomment the following cell to restart the kernel or use the button to restart the kernel. For Vertex AI Workbench you can restart the terminal using the button on top." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Automatically restart kernel after installs so that your environment can access the new packages\n", + "# import IPython\n", + "\n", + "# app = IPython.Application.instance()\n", + "# app.kernel.do_shutdown(True)" + ] + }, + { + "cell_type": "markdown", + "id": "yygMe6rPWxHS", + "metadata": { + "id": "yygMe6rPWxHS" + }, + "source": [ + "### 🔐 Authentication\n", + "Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n", + "\n", + "* If you are using Colab to run this notebook, use the cell below and continue.\n", + "* If you are using Vertex AI Workbench, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "PTXN1_DSXj2b", + "metadata": { + "id": "PTXN1_DSXj2b" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "id": "NEvB9BoLEulY", + "metadata": { + "id": "NEvB9BoLEulY" + }, + "source": [ + "### ☁ Set Your Google Cloud Project\n", + "Set your Google Cloud project so that you can leverage Google Cloud resources within this notebook.\n", + "\n", + "If you don't know your project ID, try the following:\n", + "\n", + "* Run `gcloud config list`.\n", + "* Run `gcloud projects list`.\n", + "* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "gfkS3yVRE4_W", + "metadata": { + "cellView": "form", + "id": "gfkS3yVRE4_W" + }, + "outputs": [], + "source": [ + "# @markdown Please fill in the value below with your Google Cloud project ID and then run the cell.\n", + "\n", + "PROJECT_ID = \"my-project-id\" # @param {type:\"string\"}\n", + "\n", + "# Set the project id\n", + "!gcloud config set project {PROJECT_ID}" + ] + }, + { + "cell_type": "markdown", + "id": "rEWWNoNnKOgq", + "metadata": { + "id": "rEWWNoNnKOgq" + }, + "source": [ + "### 💡 API Enablement\n", + "The `langchain-google-spanner` package requires that you [enable the Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com) in your Google Cloud Project." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5utKIdq7KYi5", + "metadata": { + "id": "5utKIdq7KYi5" + }, + "outputs": [], + "source": [ + "# enable Cloud SQL Admin API\n", + "!gcloud services enable spanner.googleapis.com" + ] + }, + { + "cell_type": "markdown", + "id": "f8f2830ee9ca1e01", + "metadata": { + "id": "f8f2830ee9ca1e01" + }, + "source": [ + "## Basic Usage" + ] + }, + { + "cell_type": "markdown", + "id": "OMvzMWRrR6n7", + "metadata": { + "id": "OMvzMWRrR6n7" + }, + "source": [ + "### Set Cloud SQL database values\n", + "Find your database values, in the [Cloud SQL Instances page](https://console.cloud.google.com/sql?_ga=2.223735448.2062268965.1707700487-2088871159.1707257687)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "irl7eMFnSPZr", + "metadata": { + "id": "irl7eMFnSPZr" + }, + "outputs": [], + "source": [ + "# @title Set Your Values Here { display-mode: \"form\" }\n", + "INSTANCE = \"my-instance\" # @param {type: \"string\"}\n", + "DATABASE = \"my-database\" # @param {type: \"string\"}\n", + "TABLE_NAME = \"message_store\" # @param {type: \"string\"}" + ] + }, + { + "cell_type": "markdown", + "id": "qPV8WfWr7O54", + "metadata": { + "id": "qPV8WfWr7O54" + }, + "source": [ + "### Initialize a table\n", + "The `SpannerChatMessageHistory` class requires a database table with a specific schema in order to store the chat message history.\n", + "\n", + "The helper method `init_chat_history_table()` that can be used to create a table with the proper schema for you." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "TEu4VHArRttE", + "metadata": { + "id": "TEu4VHArRttE" + }, + "outputs": [], + "source": [ + "from langchain_google_spanner import (\n", + " SpannerChatMessageHistory,\n", + ")\n", + "\n", + "SpannerChatMessageHistory.init_chat_history_table(table_name=TABLE_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### SpannerChatMessageHistory\n", + "\n", + "To initialize the `SpannerChatMessageHistory` class you need to provide only 3 things:\n", + "\n", + "1. `instance_id` - The name of the Spanner instance\n", + "1. `database_id` - The name of the Spanner database\n", + "1. `session_id` - A unique identifier string that specifies an id for the session.\n", + "1. `table_name` - The name of the table within the database to store the chat message history." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "message_history = SpannerChatMessageHistory(\n", + " instance_id=INSTANCE,\n", + " database_id=DATABASE,\n", + " table_name=TABLE_NAME,\n", + " session_id=\"user-session-id\",\n", + ")\n", + "\n", + "message_history.add_user_message(\"hi!\")\n", + "message_history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "message_history.messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom client\n", + "The client created by default is the default client. To use a non-default, a [custom client](https://cloud.google.com/spanner/docs/samples/spanner-create-client-with-query-options#spanner_create_client_with_query_options-python) can be passed to the constructor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from google.cloud import spanner\n", + "\n", + "custom_client_message_history = SpannerChatMessageHistory(\n", + " instance_id=\"my-instance\",\n", + " database_id=\"my-database\",\n", + " client=spanner.Client(...),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleaning up\n", + "\n", + "When the history of a specific session is obsolete and can be deleted, it can be done the following way.\n", + "Note: Once deleted, the data is no longer stored in Cloud Spanner and is gone forever." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "message_history = SpannerChatMessageHistory(\n", + " instance_id=INSTANCE,\n", + " database_id=DATABASE,\n", + " table_name=TABLE_NAME,\n", + " session_id=\"user-session-id\",\n", + ")\n", + "\n", + "message_history.clear()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index ea22ec19..bc991c43 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -28,7 +28,6 @@ steps: - 'GOOGLE_DATABASE=${_GOOGLE_DATABASE}' - 'PG_DATABASE=${_PG_DATABASE}' - 'TABLE_NAME=test_$BUILD_ID' - args: ["-m", "pytest"] substitutions: _INSTANCE_ID: test-instance diff --git a/src/langchain_google_spanner/__init__.py b/src/langchain_google_spanner/__init__.py index 51b2fcdf..16d8f772 100644 --- a/src/langchain_google_spanner/__init__.py +++ b/src/langchain_google_spanner/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from langchain_google_spanner.chat_message_history import SpannerChatMessageHistory + from .version import __version__ -__all__ = ["__version__"] +__all__ = ["__version__", "SpannerChatMessageHistory"] diff --git a/src/langchain_google_spanner/chat_message_history.py b/src/langchain_google_spanner/chat_message_history.py new file mode 100644 index 00000000..c46d7fc7 --- /dev/null +++ b/src/langchain_google_spanner/chat_message_history.py @@ -0,0 +1,194 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner-based chat message history""" +from __future__ import annotations + +from typing import List, Optional + +from google.cloud import spanner +from google.cloud.spanner_admin_database_v1.types import DatabaseDialect # type: ignore +from google.cloud.spanner_v1 import param_types +from google.cloud.spanner_v1.data_types import JsonObject +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage, messages_from_dict + +from .version import __version__ + +USER_AGENT_CHAT = "langchain-google-spanner-python:chat_history" + __version__ + +OPERATION_TIMEOUT_SECONDS = 240 + +COLUMN_FAMILY = "langchain" +COLUMN_NAME = "history" + + +def client_with_user_agent( + client: Optional[spanner.Client], user_agent: str +) -> spanner.Client: + if not client: + client = spanner.Client() + client_agent = client._client_info.user_agent + if not client_agent: + client._client_info.user_agent = user_agent + elif user_agent not in client_agent: + client._client_info.user_agent = " ".join([client_agent, user_agent]) + return client + + +class SpannerChatMessageHistory(BaseChatMessageHistory): + """Chat message history that stores history in Spanner. + + Args: + instance_id: The Spanner instance to use for chat message history. + database_id: The Spanner database to use for chat message history. + table_name: The Spanner table to use for chat message history. + session_id: Optional. The existing session ID. + """ + + def __init__( + self, + instance_id: str, + database_id: str, + session_id: str, + table_name: str, + client: Optional[spanner.Client] = None, + ) -> None: + self.instance_id = instance_id + self.database_id = database_id + self.session_id = session_id + self.table_name = table_name + self.client = client_with_user_agent(client, USER_AGENT_CHAT) + self.instance = self.client.instance(instance_id) + if not self.instance.exists(): + raise Exception("Instance doesn't exist.") + self.database = self.instance.database(database_id) + if not self.database.exists(): + raise Exception("Database doesn't exist.") + self.database.reload() + self.dialect = self.database.database_dialect + self._verify_schema() + + def _verify_schema(self) -> None: + """Verify table exists with required schema for SpannerChatMessageHistory class. + Use helper method MSSQLEngine.create_chat_history_table(...) to create + table with valid schema. + """ + # check table exists + column_names = [] # type: List[str] + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql( + f"SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.columns WHERE table_name = '{self.table_name}'" + ) + for row in results: + column_names.append(*row) + + # check that all required columns are present + required_columns = ["id", "session_id", "created_at", "message"] + if len(column_names) == 0: + raise AttributeError( + f"Table '{self.table_name}' does not exist. Please create " + "it before initializing SpannerChatMessageHistory. See " + "SpannerEngine.create_chat_history_table() for a helper method." + ) + else: + if not (all(x in column_names for x in required_columns)): + google_schema = f"""CREATE TABLE IF NOT EXISTS {self.table_name} ( + id STRING(36) DEFAULT (GENERATE_UUID()), + created_at TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), + session_id STRING(MAX) NOT NULL, + message JSON NOT NULL, + ) PRIMARY KEY (session_id, created_at ASC, id)""" + pg_schema = f"""CREATE TABLE IF NOT EXISTS {self.table_name} ( + id varchar(36) DEFAULT (spanner.generate_uuid()), + created_at SPANNER.COMMIT_TIMESTAMP NOT NULL, + session_id TEXT NOT NULL, + message JSONB NOT NULL, + PRIMARY KEY (session_id, created_at, id)""" + ddl = ( + pg_schema + if self.dialect == DatabaseDialect.POSTGRESQL + else google_schema + ) + raise IndexError( + f"Table '{self.table_name}' has incorrect schema. Got " + f"column names '{column_names}' but required column names " + f"'{required_columns}'.\nPlease create table with following schema:" + f"{ddl};" + ) + + def create_chat_history_table(self) -> None: + google_schema = f"""CREATE TABLE IF NOT EXISTS {self.table_name} ( + id STRING(36) DEFAULT (GENERATE_UUID()), + created_at TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), + session_id STRING(MAX) NOT NULL, + message JSON NOT NULL, + ) PRIMARY KEY (session_id, created_at ASC, id)""" + + pg_schema = f"""CREATE TABLE IF NOT EXISTS {self.table_name} ( + id varchar(36) DEFAULT (spanner.generate_uuid()), + created_at SPANNER.COMMIT_TIMESTAMP NOT NULL, + session_id TEXT NOT NULL, + message JSONB NOT NULL, + PRIMARY KEY (session_id, created_at, id) + );""" + + ddl = pg_schema if self.dialect == DatabaseDialect.POSTGRESQL else google_schema + database = self.client.instance(self.instance_id).database(self.database_id) + operation = database.update_ddl([ddl]) + operation.result(OPERATION_TIMEOUT_SECONDS) + return operation + + @property + def messages(self) -> List[BaseMessage]: # type: ignore + """Retrieve the messages from Cloud Spanner""" + place_holder = "$1" if self.dialect == DatabaseDialect.POSTGRESQL else "@p1" + query = f"SELECT message FROM {self.table_name} WHERE session_id = {place_holder} ORDER BY created_at;" + param = {"p1": self.session_id} + param_type = {"p1": param_types.STRING} + + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql( + query, + params=param, + param_types=param_type, + ) + items = [] # type: List[dict] + for row in results: + items.append({"data": row[0], "type": row[0]["type"]}) + messages = messages_from_dict(items) + return messages + + def add_message(self, message: BaseMessage) -> None: + """Append the message to the record in Cloud Spanner""" + with self.database.batch() as batch: + batch.insert( + table=self.table_name, + columns=("session_id", "created_at", "message"), + values=[ + ( + self.session_id, + spanner.COMMIT_TIMESTAMP, + JsonObject(message.dict()), + ), + ], + ) + + def clear(self) -> None: + """Clear session memory from Cloud Spanner""" + place_holder = "$1" if self.dialect == DatabaseDialect.POSTGRESQL else "@p1" + query = f"DELETE FROM {self.table_name} WHERE session_id = {place_holder};" + param = {"p1": self.session_id} + param_type = {"p1": param_types.STRING} + self.database.execute_partitioned_dml(query, param, param_type) diff --git a/tests/integration/test_spanner_chat_message_history.py b/tests/integration/test_spanner_chat_message_history.py new file mode 100644 index 00000000..1e2f56d2 --- /dev/null +++ b/tests/integration/test_spanner_chat_message_history.py @@ -0,0 +1,89 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +import pytest # noqa +from google.cloud.spanner import Client # type: ignore +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.human import HumanMessage + +from langchain_google_spanner import SpannerChatMessageHistory + +project_id = os.environ["PROJECT_ID"] +instance_id = os.environ["INSTANCE_ID"] +table_name = os.environ["TABLE_NAME"].replace("-", "_") + +OPERATION_TIMEOUT_SECONDS = 240 + + +@pytest.fixture(scope="module") +def client() -> Client: + return Client(project=project_id) + + +@pytest.fixture(scope="module") +def setup(client): + for env in ["GOOGLE_DATABASE", "PG_DATABASE"]: + google_schema = f"""CREATE TABLE IF NOT EXISTS {table_name} ( + id STRING(36) DEFAULT (GENERATE_UUID()), + created_at TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), + session_id STRING(MAX) NOT NULL, + message JSON NOT NULL, + ) PRIMARY KEY (session_id, created_at ASC, id)""" + + pg_schema = f"""CREATE TABLE IF NOT EXISTS {table_name} ( + id varchar(36) DEFAULT (spanner.generate_uuid()), + created_at SPANNER.COMMIT_TIMESTAMP NOT NULL, + session_id TEXT NOT NULL, + message JSONB NOT NULL, + PRIMARY KEY (session_id, created_at, id) + );""" + database_id = os.environ.get(env) + ddl = pg_schema if env == "PG_DATABASE" else google_schema + database = client.instance(instance_id).database(database_id) + operation = database.update_ddl([ddl]) + operation.result(OPERATION_TIMEOUT_SECONDS) + yield + for env in ["GOOGLE_DATABASE", "PG_DATABASE"]: + database_id = os.environ.get(env) + database = client.instance(instance_id).database(database_id) + operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name}"]) + operation.result(OPERATION_TIMEOUT_SECONDS) + + +def test_chat_message_history(setup) -> None: + for env in ["GOOGLE_DATABASE", "PG_DATABASE"]: + database_id = os.environ.get(env) + assert database_id is not None + history = SpannerChatMessageHistory( + instance_id=instance_id, + database_id=database_id, + session_id="test-session", + table_name=table_name, + ) + history.add_user_message("hi!") + history.add_ai_message("whats up?") + messages = history.messages + + # verify messages are correct + assert messages[0].content == "hi!" + assert type(messages[0]) is HumanMessage + assert messages[1].content == "whats up?" + assert type(messages[1]) is AIMessage + + # verify clear() clears message history + history.clear() + assert len(history.messages) == 0