diff --git a/.gitignore b/.gitignore index ab7848f74..c972d59de 100644 --- a/.gitignore +++ b/.gitignore @@ -234,6 +234,6 @@ apps/openwork-memos-integration/apps/desktop/public/assets/usecases/ # Outputs and Evaluation Results outputs -evaluation/data/temporal_locomo +evaluation/data/ test_add_pipeline.py test_file_pipeline.py diff --git a/Makefile b/Makefile index 788504a73..eb22e241d 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ pre_commit: poetry run pre-commit run -a serve: - poetry run uvicorn memos.api.start_api:app + poetry run uvicorn memos.api.server_api:app openapi: poetry run memos export_openapi --output docs/openapi.json diff --git a/README.md b/README.md index a7b05d683..82ec02772 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ - [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) โ€” intelligent memory retrieval instead of loading full chat history - [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) โ€” multi-instance agents share memory via same user_id, automatic context handoff -Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) +Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin) ### ๐Ÿง  Local Plugin โ€” 100% On-Device Memory @@ -84,7 +84,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - **Hybrid search + task & skill evolution** โ€” FTS5 + vector search, auto task summarization, reusable skills that self-upgrade - **Multi-agent collaboration + Memory Viewer** โ€” memory isolation, skill sharing, full web dashboard with 7 management pages - ๐ŸŒ [Homepage](https://memos-claw.openmem.net) ยท + ๐ŸŒ [Homepage](https://memos-claw.openmem.net) ยท ๐Ÿ“– [Documentation](https://memos-claw.openmem.net/docs/index.html) ยท ๐Ÿ“ฆ 
[NPM](https://www.npmjs.com/package/@memtensor/memos-local-openclaw-plugin) ## ๐Ÿ“Œ MemOS: Memory Operating System for AI Agents @@ -104,10 +104,10 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem ### News -- **2026-03-08** ยท ๐Ÿฆž **MemOS OpenClaw Plugin โ€” Cloud & Local** +- **2026-03-08** ยท ๐Ÿฆž **MemOS OpenClaw Plugin โ€” Cloud & Local** Official OpenClaw memory plugins launched. **Cloud Plugin**: hosted memory service with 72% lower token usage and multi-agent memory sharing ([MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)). **Local Plugin** (`v1.0.0`): 100% on-device memory with persistent SQLite, hybrid search (FTS5 + vector), task summarization & skill evolution, multi-agent collaboration, and a full Memory Viewer dashboard. -- **2025-12-24** ยท ๐ŸŽ‰ **MemOS v2.0: Stardust (ๆ˜Ÿๅฐ˜) Release** +- **2025-12-24** ยท ๐ŸŽ‰ **MemOS v2.0: Stardust (ๆ˜Ÿๅฐ˜) Release** Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
โœจ New Features @@ -155,7 +155,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **2025-08-07** ยท ๐ŸŽ‰ **MemOS v1.0.0 (MemCube) Release** - First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch. + First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, improved search capabilities, and the official Playground launch.
โœจ New Features @@ -176,7 +176,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem **Plaintext Memory** - Integrated internet search with Bocha. - - Added support for Nebula database. + - Expanded graph database support. - Added contextual understanding for the tree-structured plaintext memory search interface.
@@ -188,7 +188,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - Fixed the concat_cache method. **Plaintext Memory** - - Fixed Nebula search-related issues. + - Fixed graph search-related issues. @@ -224,6 +224,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem 2. Configure `docker/.env.example` and copy to `MemOS/.env` - The `OPENAI_API_KEY`,`MOS_EMBEDDER_API_KEY`,`MEMRADER_API_KEY` and others can be applied for through [`BaiLian`](https://bailian.console.aliyun.com/?spm=a2c4g.11186623.0.0.2f2165b08fRk4l&tab=api#/api). - Fill in the corresponding configuration in the `MemOS/.env` file. + - Supported LLM providers: **OpenAI**, **Azure OpenAI**, **Qwen (DashScope)**, **DeepSeek**, **MiniMax**, **Ollama**, **HuggingFace**, **vLLM**. Set `MOS_CHAT_MODEL_PROVIDER` to select the backend (e.g., `openai`, `qwen`, `deepseek`, `minimax`). 3. Start the service. - Launch via Docker diff --git a/apps/memos-local-openclaw/.env.example b/apps/memos-local-openclaw/.env.example index bfb409298..ce292ba80 100644 --- a/apps/memos-local-openclaw/.env.example +++ b/apps/memos-local-openclaw/.env.example @@ -18,6 +18,10 @@ SUMMARIZER_TEMPERATURE=0 # Port for the web-based Memory Viewer (default: 18799) # VIEWER_PORT=18799 +# โ”€โ”€โ”€ Tavily Search (optional) โ”€โ”€โ”€ +# API key for Tavily web search (get from https://app.tavily.com) +# TAVILY_API_KEY=tvly-your-tavily-api-key + # โ”€โ”€โ”€ Telemetry (opt-out) โ”€โ”€โ”€ # Anonymous usage analytics to help improve the plugin. # No memory content, queries, or personal data is ever sent โ€” only tool names, latencies, and version info. 
diff --git a/docker/.env.example b/docker/.env.example index 3674cd69b..c9d8e714e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -25,9 +25,12 @@ MOS_MAX_TOKENS=2048 # Top-P for LLM in the Product API MOS_TOP_P=0.9 # LLM for the Product API backend -MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm +MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm | minimax OPENAI_API_KEY=sk-xxx # [required] when provider=openai OPENAI_API_BASE=https://api.openai.com/v1 # [required] base for the key +# MiniMax LLM (when provider=minimax) +# MINIMAX_API_KEY=your-minimax-api-key # [required] when provider=minimax +# MINIMAX_API_BASE=https://api.minimax.io/v1 # base for MiniMax API ## MemReader / retrieval LLM MEMRADER_MODEL=gpt-4o-mini @@ -80,8 +83,12 @@ EMBEDDING_MODEL=nomic-embed-text:latest ## Internet search & preference memory # Enable web search ENABLE_INTERNET=false +# Internet search backend (bocha | tavily) +INTERNET_SEARCH_BACKEND=bocha # API key for BOCHA Search -BOCHA_API_KEY= # required if ENABLE_INTERNET=true +BOCHA_API_KEY= # required if ENABLE_INTERNET=true and backend=bocha +# API key for Tavily Search +TAVILY_API_KEY= # required if ENABLE_INTERNET=true and backend=tavily # default search mode SEARCH_MODE=fast # fast | fine | mixture # Slow retrieval strategy configuration, rewrite is the rewrite strategy @@ -127,7 +134,7 @@ MEMSCHEDULER_USE_REDIS_QUEUE=false ## Graph / vector stores # Neo4j database selection mode -NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb +NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | polardb | postgres # Neo4j database url NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j* # Neo4j database user diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 0a8e2c634..6805ec781 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -26,7 +26,7 @@ services: - memos_network neo4j: - image: neo4j:5.26.4 + image: neo4j:5.26.6 
container_name: neo4j-docker ports: - "7474:7474" # HTTP diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt index a14257a76..b148d43d6 100644 --- a/docker/requirements-full.txt +++ b/docker/requirements-full.txt @@ -185,3 +185,4 @@ py-key-value-shared==0.2.8 PyJWT==2.10.1 pytest==9.0.2 alibabacloud-oss-v2==1.2.2 +tavily-python==0.5.0 diff --git a/docker/requirements.txt b/docker/requirements.txt index 340f4e140..988e64b83 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -124,3 +124,4 @@ uvloop==0.22.1; sys_platform != 'win32' watchfiles==1.1.1 websockets==15.0.1 alibabacloud-oss-v2==1.2.2 +tavily-python==0.5.0 diff --git a/examples/basic_modules/llm.py b/examples/basic_modules/llm.py index fb157c991..3fd7352c7 100644 --- a/examples/basic_modules/llm.py +++ b/examples/basic_modules/llm.py @@ -164,7 +164,37 @@ print("Scenario 6:", resp) -# Scenario 7: Using LLMFactory with Deepseek-chat + reasoning + CoT + streaming +# Scenario 7: Using LLMFactory with MiniMax (OpenAI-compatible API) +# Prerequisites: +# 1. Get your API key from the MiniMax platform. +# 2. Available models: MiniMax-M2.7 (flagship), MiniMax-M2.7-highspeed (low-latency), +# MiniMax-M2.5, MiniMax-M2.5-highspeed. 
+ +cfg_mm = LLMConfigFactory.model_validate( + { + "backend": "minimax", + "config": { + "model_name_or_path": "MiniMax-M2.7", + "api_key": "your-minimax-api-key", + "api_base": "https://api.minimax.io/v1", + "temperature": 0.7, + "max_tokens": 1024, + }, + } +) +llm = LLMFactory.from_config(cfg_mm) +messages = [{"role": "user", "content": "Hello, who are you"}] +resp = llm.generate(messages) +print("Scenario 7:", resp) +print("==" * 20) + +print("Scenario 7 (streaming):\n") +for chunk in llm.generate_stream(messages): + print(chunk, end="") +print("\n" + "==" * 20) + + +# Scenario 8: Using LLMFactory with DeepSeek-chat + reasoning + CoT + streaming cfg2 = LLMConfigFactory.model_validate( { @@ -186,7 +216,7 @@ "content": "Explain how to solve this problem step-by-step. Be explicit in your thinking process. Question: If a train travels from city A to city B at 60 mph and returns at 40 mph, what is its average speed for the entire trip? Let's think step by step.", }, ] -print("Scenario 7:\n") +print("Scenario 8:\n") for chunk in llm.generate_stream(messages): print(chunk, end="") print("==" * 20) diff --git a/examples/basic_modules/neo4j_example.py b/examples/basic_modules/neo4j_example.py index e1c0df317..2a5e88749 100644 --- a/examples/basic_modules/neo4j_example.py +++ b/examples/basic_modules/neo4j_example.py @@ -2,6 +2,8 @@ from datetime import datetime +from dotenv import load_dotenv + from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphDBConfigFactory from memos.embedders.factory import EmbedderFactory @@ -9,14 +11,27 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +load_dotenv() + +NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") +NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") +NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678") +NEO4J_DB_NAME = os.getenv("NEO4J_DB_NAME", "neo4j") +EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIMENSION", "3072")) + 
+QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost") +QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) + embedder_config = EmbedderConfigFactory.model_validate( { - "backend": "universal_api", + "backend": os.getenv("MOS_EMBEDDER_BACKEND", "universal_api"), "config": { - "provider": "openai", - "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), - "model_name_or_path": "text-embedding-3-large", - "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + "provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"), + "api_key": os.getenv("MOS_EMBEDDER_API_KEY", os.getenv("OPENAI_API_KEY", "")), + "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), + "base_url": os.getenv( + "MOS_EMBEDDER_API_BASE", os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") + ), }, } ) @@ -31,12 +46,12 @@ def get_neo4j_graph(db_name: str = "paper"): config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://xxxx:7687", - "user": "neo4j", - "password": "xxxx", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, "use_multi_db": True, }, ) @@ -49,12 +64,12 @@ def example_multi_db(db_name: str = "paper"): config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, "use_multi_db": True, }, ) @@ -288,14 +303,14 @@ def example_shared_db(db_name: str = "shared-traval-group"): config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_name, "use_multi_db": 
False, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, }, ) # Step 2: Instantiate graph store @@ -353,12 +368,12 @@ def example_shared_db(db_name: str = "shared-traval-group"): config_alice = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_list[0], - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, }, ) graph_alice = GraphStoreFactory.from_config(config_alice) @@ -382,24 +397,22 @@ def run_user_session( config = GraphDBConfigFactory( backend="neo4j-community", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_name, "use_multi_db": False, - "auto_create": False, # Neo4j Community does not allow auto DB creation - "embedding_dimension": 3072, + "auto_create": False, + "embedding_dimension": EMBEDDING_DIMENSION, "vec_config": { - # Pass nested config to initialize external vector DB - # If you use qdrant, please use Server instead of local mode. 
"backend": "qdrant", "config": { "collection_name": "neo4j_vec_db", - "vector_dimension": 3072, + "vector_dimension": EMBEDDING_DIMENSION, "distance_metric": "cosine", - "host": "localhost", - "port": 6333, + "host": QDRANT_HOST, + "port": QDRANT_PORT, }, }, }, @@ -408,14 +421,14 @@ def run_user_session( config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_name, "use_multi_db": False, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, }, ) graph = GraphStoreFactory.from_config(config) diff --git a/examples/basic_modules/textual_memory_internet_search_example.py b/examples/basic_modules/textual_memory_internet_search_example.py index 9007d7e67..48c5b5f72 100644 --- a/examples/basic_modules/textual_memory_internet_search_example.py +++ b/examples/basic_modules/textual_memory_internet_search_example.py @@ -22,6 +22,8 @@ **Prerequisites:** - Valid BochaAI API Key (set in environment variable: BOCHA_API_KEY) +- (Optional) Valid Tavily API Key (set in environment variable: TAVILY_API_KEY) + - Get from https://app.tavily.com (1,000 free credits/month) - (Optional) Valid Google API Key and Search Engine ID for Google Custom Search - GOOGLE_API_KEY: Get from https://console.cloud.google.com/ - GOOGLE_SEARCH_ENGINE_ID: Get from https://programmablesearchengine.google.com/ @@ -288,19 +290,102 @@ print("\n Get your credentials from:") print(" https://developers.google.com/custom-search/v1/overview") +# ============================================================================ +# Step 7: Test Tavily Search API (Optional) +# ============================================================================ +print("\n" + "=" * 80) +print("TAVILY SEARCH API TEST") +print("=" * 80) + +tavily_api_key = os.environ.get("TAVILY_API_KEY", "") + +if 
tavily_api_key: + print("\n[Step 7.1] Configuring Tavily Search retriever...") + + tavily_retriever_config = InternetRetrieverConfigFactory.model_validate( + { + "backend": "tavily", + "config": { + "api_key": tavily_api_key, + "max_results": 5, + "search_depth": "basic", + "include_answer": False, + }, + } + ) + + print(" Retriever configured: tavily") + print(f" Max results: {tavily_retriever_config.config.max_results}") + + print("\n[Step 7.2] Creating Tavily retriever instance...") + tavily_retriever = InternetRetrieverFactory.from_config(tavily_retriever_config, embedder) + print(" Tavily retriever initialized") + + print("\n[Step 7.3] Performing Tavily web search...") + tavily_query = "latest developments in AI 2024" + print(f" Query: '{tavily_query}'") + print(" Searching via Tavily Search API...\n") + + tavily_results = tavily_retriever.retrieve_from_internet(tavily_query) + + print(f" Tavily search completed! Retrieved {len(tavily_results)} memory items\n") + + print("=" * 80) + print("TAVILY SEARCH RESULTS") + print("=" * 80) + + if not tavily_results: + print("\n No results found from Tavily.") + print(" This might indicate:") + print(" - Invalid Tavily API key") + print(" - Network connectivity issues") + else: + for idx, item in enumerate(tavily_results, 1): + print(f"\n[Tavily Result #{idx}]") + print("-" * 80) + + content = item.memory + if len(content) > 300: + print(f"Content: {content[:300]}...") + print(f" (... 
{len(content) - 300} more characters)") + else: + print(f"Content: {content}") + + if hasattr(item, "metadata") and item.metadata: + metadata = item.metadata + if hasattr(metadata, "sources") and metadata.sources: + print(f"Source: {metadata.sources[0] if metadata.sources else 'N/A'}") + + print() + + print("=" * 80) + print("Tavily Search Test completed!") + print("=" * 80) +else: + print("\n Skipping Tavily Search API test") + print(" To enable this test, set the following environment variable:") + print(" - TAVILY_API_KEY: Your Tavily API key (get from https://app.tavily.com)") + print("\n" + "=" * 80) print("ALL TESTS COMPLETED") print("=" * 80) -print("\n๐Ÿ’ก Summary:") -print(" โœ“ Tested BochaAI web search retriever") +print("\n Summary:") +print(" - Tested BochaAI web search retriever") if google_api_key and google_search_engine_id: - print(" โœ“ Tested Google Custom Search API") + print(" - Tested Google Custom Search API") else: - print(" โญ๏ธ Skipped Google Custom Search API (credentials not set)") -print("\n๐Ÿ’ก Quick Start:") + print(" - Skipped Google Custom Search API (credentials not set)") +if tavily_api_key: + print(" - Tested Tavily Search API") +else: + print(" - Skipped Tavily Search API (credentials not set)") +print("\n Quick Start:") print(" # Set BochaAI API key") print(" export BOCHA_API_KEY='sk-your-bocha-api-key'") print(" ") +print(" # Set Tavily API key (optional)") +print(" export TAVILY_API_KEY='tvly-your-tavily-api-key'") +print(" ") print(" # Set Google Custom Search credentials (optional)") print(" export GOOGLE_API_KEY='your-google-api-key'") print(" export GOOGLE_SEARCH_ENGINE_ID='your-search-engine-id'") diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 6dbe202c2..d14b6d687 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -47,7 +47,7 @@ def build_minimal_components(): # Build component configurations using APIConfig 
methods (like config_builders.py) - # Graph DB configuration - using APIConfig.get_nebular_config() + # Graph DB configuration - using APIConfig graph DB helpers graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower() graph_db_backend_map = { "polardb": APIConfig.get_polardb_config(), diff --git a/examples/mem_mcp/simple_fastmcp_serve.py b/examples/mem_mcp/simple_fastmcp_serve.py index 55ad4d84d..5071314e0 100644 --- a/examples/mem_mcp/simple_fastmcp_serve.py +++ b/examples/mem_mcp/simple_fastmcp_serve.py @@ -23,7 +23,7 @@ def add_memory(memory_content: str, user_id: str, cube_id: str | None = None): """Add memory using the Server API.""" payload = { "user_id": user_id, - "messages": memory_content, + "messages": [{"role": "user", "content": memory_content}], "writable_cube_ids": [cube_id] if cube_id else None, } try: diff --git a/poetry.lock b/poetry.lock index 72049f025..dfea31354 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -1897,7 +1897,7 @@ files = [ [package.dependencies] attrs = ">=22.2.0" -jsonschema-specifications = ">=2023.3.6" +jsonschema-specifications = ">=2023.03.6" referencing = ">=0.28.4" rpds-py = ">=0.7.1" @@ -2672,7 +2672,6 @@ files = [ {file = "nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1"}, {file = "nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868"}, ] -markers = {main = "extra == \"all\""} [package.dependencies] click = "*" @@ -5476,6 +5475,24 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +[[package]] +name = "tavily-python" +version = "0.7.23" +description = "Python wrapper for the Tavily API" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"tavily\" or extra == \"all\"" +files = [ + {file = "tavily_python-0.7.23-py3-none-any.whl", hash = "sha256:52ef85c44b926bce3f257570cd32bc1bd4db54666acf3105617f27411a59e188"}, + {file = "tavily_python-0.7.23.tar.gz", hash = "sha256:3b92232e0e29ab68898b765f281bb4f2c650b02210b64affbc48e15292e96161"}, +] + +[package.dependencies] +httpx = "*" +requests = "*" +tiktoken = ">=0.5.1" + [[package]] name = "tenacity" version = "9.1.2" @@ -5504,6 +5521,81 @@ files = [ {file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"}, ] +[[package]] +name = "tiktoken" +version = "0.12.0" +description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" +optional = true +python-versions = ">=3.9" +groups = ["main"] +markers = "extra == \"tavily\" or extra == \"all\"" +files = [ + {file = "tiktoken-0.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3de02f5a491cfd179aec916eddb70331814bd6bf764075d39e21d5862e533970"}, + {file = "tiktoken-0.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:b6cfb6d9b7b54d20af21a912bfe63a2727d9cfa8fbda642fd8322c70340aad16"}, + {file = "tiktoken-0.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:cde24cdb1b8a08368f709124f15b36ab5524aac5fa830cc3fdce9c03d4fb8030"}, + {file = "tiktoken-0.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6de0da39f605992649b9cfa6f84071e3f9ef2cec458d08c5feb1b6f0ff62e134"}, + {file = "tiktoken-0.12.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6faa0534e0eefbcafaccb75927a4a380463a2eaa7e26000f0173b920e98b720a"}, + {file = "tiktoken-0.12.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:82991e04fc860afb933efb63957affc7ad54f83e2216fe7d319007dab1ba5892"}, + {file = "tiktoken-0.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:6fb2995b487c2e31acf0a9e17647e3b242235a20832642bb7a9d1a181c0c1bb1"}, + {file = "tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb"}, + {file = "tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa"}, + {file = "tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc"}, + {file = "tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded"}, + {file = "tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd"}, + {file = "tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967"}, + {file = "tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def"}, + {file = "tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8"}, + 
{file = "tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b"}, + {file = "tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37"}, + {file = "tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad"}, + {file = "tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5"}, + {file = "tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3"}, + {file = "tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd"}, + {file = "tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3"}, + {file = "tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160"}, + {file = "tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa"}, + {file = "tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be"}, + {file = "tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a"}, + {file = "tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3"}, + {file = "tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697"}, + {file = "tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = 
"sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16"}, + {file = "tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a"}, + {file = "tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27"}, + {file = "tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb"}, + {file = "tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e"}, + {file = "tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25"}, + {file = "tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f"}, + {file = "tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646"}, + {file = "tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88"}, + {file = "tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff"}, + {file = "tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830"}, + {file = "tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b"}, + {file = "tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b"}, + {file = "tiktoken-0.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:399c3dd672a6406719d84442299a490420b458c44d3ae65516302a99675888f3"}, + 
{file = "tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365"}, + {file = "tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e"}, + {file = "tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63"}, + {file = "tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0"}, + {file = "tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a"}, + {file = "tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0"}, + {file = "tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71"}, + {file = "tiktoken-0.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:d51d75a5bffbf26f86554d28e78bfb921eae998edc2675650fd04c7e1f0cdc1e"}, + {file = "tiktoken-0.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:09eb4eae62ae7e4c62364d9ec3a57c62eea707ac9a2b2c5d6bd05de6724ea179"}, + {file = "tiktoken-0.12.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:df37684ace87d10895acb44b7f447d4700349b12197a526da0d4a4149fde074c"}, + {file = "tiktoken-0.12.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:4c9614597ac94bb294544345ad8cf30dac2129c05e2db8dc53e082f355857af7"}, + {file = "tiktoken-0.12.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:20cf97135c9a50de0b157879c3c4accbb29116bcf001283d26e073ff3b345946"}, + {file = "tiktoken-0.12.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:15d875454bbaa3728be39880ddd11a5a2a9e548c29418b41e8fd8a767172b5ec"}, + {file = "tiktoken-0.12.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:2cff3688ba3c639ebe816f8d58ffbbb0aa7433e23e08ab1cade5d175fc973fb3"}, + {file = "tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931"}, +] + +[package.dependencies] +regex = ">=2022.1.18" +requests = ">=2.26.0" + +[package.extras] +blobfile = ["blobfile (>=2)"] + [[package]] name = "tokenizers" version = "0.21.2" @@ -6544,15 +6636,16 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["alibabacloud-oss-v2", "cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "nltk", "pika", "pymilvus", "pymysql", "qdrant-client", "rake-nltk", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["alibabacloud-oss-v2", "cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "nltk", "pika", "pymilvus", "pymysql", "qdrant-client", "rake-nltk", "rank-bm25", "redis", "schedule", "sentence-transformers", "tavily-python", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "langchain-text-splitters", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] pref-mem = ["datasketch", "pymilvus"] skill-mem = ["alibabacloud-oss-v2"] +tavily = ["tavily-python"] tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "e0427aa672e57215033fe964847474521abf61b0a63f443744a6ec0b8c5ff2e2" +content-hash = "86c718082278e9df7b994b87eb4b75c64ee4b5353954bf16dbec371d8264ce0a" diff --git a/pyproject.toml b/pyproject.toml index de8e66ad1..a359ee498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,11 @@ skill-mem = [ "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)", ] +# Tavily Search +tavily = [ + "tavily-python (>=0.5.0,<1.0.0)", +] + # All optional dependencies # Allow users to install with `pip install MemoryOS[all]` all = [ @@ -129,6 +134,7 @@ all = [ "nltk 
(>=3.9.1,<4.0.0)", "rake-nltk (>=1.0.6,<1.1.0)", "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)", + "tavily-python (>=0.5.0,<1.0.0)", # Uncategorized dependencies ] diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 87f1efd8e..c68deae5a 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -284,6 +284,20 @@ def qwen_config() -> dict[str, Any]: "remove_think_prefix": True, } + @staticmethod + def minimax_config() -> dict[str, Any]: + """Get MiniMax configuration.""" + return { + "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "MiniMax-M2.7"), + "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")), + "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")), + "top_p": float(os.getenv("MOS_TOP_P", "0.9")), + "top_k": int(os.getenv("MOS_TOP_K", "50")), + "remove_think_prefix": True, + "api_key": os.getenv("MINIMAX_API_KEY", "your-api-key-here"), + "api_base": os.getenv("MINIMAX_API_BASE", "https://api.minimax.io/v1"), + } + @staticmethod def vllm_config() -> dict[str, Any]: """Get Qwen configuration.""" @@ -626,7 +640,26 @@ def get_oss_config() -> dict[str, Any] | None: return config def get_internet_config() -> dict[str, Any]: - """Get embedder configuration.""" + """Get internet retriever configuration. + + Supports backends: bocha (default), tavily, google, bing, xinyu. + Set INTERNET_SEARCH_BACKEND env var to choose the backend. + For Tavily, set TAVILY_API_KEY env var. + For Bocha, set BOCHA_API_KEY env var. 
+ """ + backend = os.getenv("INTERNET_SEARCH_BACKEND", "bocha").lower() + + if backend == "tavily": + return { + "backend": "tavily", + "config": { + "api_key": os.getenv("TAVILY_API_KEY", ""), + "max_results": 10, + "search_depth": os.getenv("TAVILY_SEARCH_DEPTH", "basic"), + "include_answer": os.getenv("TAVILY_INCLUDE_ANSWER", "false").lower() == "true", + }, + } + reader_config = APIConfig.get_reader_config() return { "backend": "bocha", @@ -741,21 +774,6 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } - @staticmethod - def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: - """Get Nebular configuration.""" - return { - "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')), - "user": os.getenv("NEBULAR_USER", "root"), - "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"), - "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"), - "user_name": f"memos{user_id.replace('-', '')}", - "use_multi_db": False, - "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), - } - - @staticmethod def get_milvus_config(): return { "collection_name": [ @@ -916,12 +934,14 @@ def get_product_default_config() -> dict[str, Any]: openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() + minimax_config = APIConfig.minimax_config() reader_config = APIConfig.get_reader_config() backend_model = { "openai": openai_config, "huggingface": qwen_config, "vllm": vllm_config, + "minimax": minimax_config, } backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") mysql_config = APIConfig.get_mysql_config() @@ -1039,6 +1059,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() + minimax_config = 
APIConfig.minimax_config() mysql_config = APIConfig.get_mysql_config() reader_config = APIConfig.get_reader_config() backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") @@ -1046,6 +1067,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene "openai": openai_config, "huggingface": qwen_config, "vllm": vllm_config, + "minimax": minimax_config, } # Create MOSConfig config_dict = { @@ -1103,7 +1125,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene neo4j_community_config = APIConfig.get_neo4j_community_config(user_id) neo4j_config = APIConfig.get_neo4j_config(user_id) - nebular_config = APIConfig.get_nebular_config(user_id) polardb_config = APIConfig.get_polardb_config(user_id) internet_config = ( APIConfig.get_internet_config() @@ -1114,7 +1135,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, - "nebular": nebular_config, "polardb": polardb_config, "postgres": postgres_config, } @@ -1144,9 +1164,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true", "memory_size": { - "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), - "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), - "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), + "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -1169,7 +1189,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene } ) else: - raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") + raise ValueError(f"Invalid graph DB backend: {graph_db_backend}") 
default_mem_cube = GeneralMemCube(default_cube_config) return default_config, default_mem_cube @@ -1188,13 +1208,11 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": openai_config = APIConfig.get_openai_config() neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default") neo4j_config = APIConfig.get_neo4j_config(user_id="default") - nebular_config = APIConfig.get_nebular_config(user_id="default") polardb_config = APIConfig.get_polardb_config(user_id="default") postgres_config = APIConfig.get_postgres_config(user_id="default") graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, - "nebular": nebular_config, "polardb": polardb_config, "postgres": postgres_config, } @@ -1227,9 +1245,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": == "true", "internet_retriever": internet_config, "memory_size": { - "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), - "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), - "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), + "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -1253,4 +1271,4 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": } ) else: - raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") + raise ValueError(f"Invalid graph DB backend: {graph_db_backend}") diff --git a/src/memos/api/exceptions.py b/src/memos/api/exceptions.py index 10a14b4d1..15644113c 100644 --- a/src/memos/api/exceptions.py +++ b/src/memos/api/exceptions.py @@ -14,13 +14,26 @@ class APIExceptionHandler: @staticmethod async def validation_error_handler(request: Request, exc: RequestValidationError): """Handle request validation errors.""" - logger.error(f"Validation error: 
{exc.errors()}") + errors = exc.errors() + path = request.url.path + method = request.method + + readable_errors = [] + for err in errors: + loc = " -> ".join(str(loc_i) for loc_i in err.get("loc", [])) + readable_errors.append( + f"[{loc}] {err.get('msg', 'unknown error')} (type: {err.get('type', 'unknown')})" + ) + + logger.error( + f"Validation error on {method} {path}: {readable_errors}, raw errors: {errors}" + ) return JSONResponse( status_code=422, content={ "code": 422, - "message": "Parameter validation error", - "detail": exc.errors(), + "message": f"Parameter validation error on {method} {path}: {'; '.join(readable_errors)}", + "detail": errors, "data": None, }, ) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index e071eacb3..eb982fd06 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -36,7 +36,6 @@ def __init__( vector_db: Any | None = None, internet_retriever: Any | None = None, memory_manager: Any | None = None, - mos_server: Any | None = None, feedback_server: Any | None = None, **kwargs, ): @@ -54,7 +53,6 @@ def __init__( vector_db: Vector database instance internet_retriever: Internet retriever instance memory_manager: Memory manager instance - mos_server: MOS server instance **kwargs: Additional dependencies """ self.llm = llm @@ -68,7 +66,6 @@ def __init__( self.vector_db = vector_db self.internet_retriever = internet_retriever self.memory_manager = memory_manager - self.mos_server = mos_server self.feedback_server = feedback_server # Store any additional dependencies @@ -158,11 +155,6 @@ def vector_db(self): """Get vector database instance.""" return self.deps.vector_db - @property - def mos_server(self): - """Get MOS server instance.""" - return self.deps.mos_server - @property def deepsearch_agent(self): """Get deepsearch agent instance.""" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 
7894ff7dc..a01fffef8 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -28,7 +28,6 @@ from memos.log import get_logger from memos.mem_cube.navie import NaiveMemCube from memos.mem_feedback.simple_feedback import SimpleMemFeedback -from memos.mem_os.product_server import MOSServer from memos.mem_reader.factory import MemReaderFactory from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory @@ -211,15 +210,6 @@ def init_server() -> dict[str, Any]: logger.debug("Text memory initialized") - # Initialize MOS Server - mos_server = MOSServer( - mem_reader=mem_reader, - llm=llm, - online_bot=False, - ) - - logger.debug("MOS server initialized") - # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( text_mem=text_mem, @@ -304,7 +294,6 @@ def init_server() -> dict[str, Any]: "internet_retriever": internet_retriever, "memory_manager": memory_manager, "default_cube_config": default_cube_config, - "mos_server": mos_server, "mem_scheduler": mem_scheduler, "naive_mem_cube": naive_mem_cube, "searcher": searcher, diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index 5655bf1e5..d29429fc9 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -39,13 +39,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: graph_db_backend_map = { "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), "postgres": APIConfig.get_postgres_config(user_id=user_id), } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() + 
graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 72c6a3da8..dc052fdbe 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -314,26 +314,99 @@ def handle_get_memories( return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results) +def _build_quick_delete_constraints(delete_mem_req: DeleteMemoryRequest) -> dict[str, Any]: + """Build fast-delete constraints from request-level fields.""" + constraints: dict[str, Any] = {} + if delete_mem_req.user_id is not None: + constraints["user_id"] = delete_mem_req.user_id + if delete_mem_req.session_id is not None: + constraints["session_id"] = delete_mem_req.session_id + return constraints + + +def _merge_delete_filter( + base_filter: dict[str, Any] | None, + constraints: dict[str, Any], +) -> dict[str, Any]: + """Merge user/session constraints into an existing filter.""" + if not constraints: + return base_filter or {} + if base_filter is None: + return {"and": [constraints.copy()]} + + if not base_filter: + return {"and": [constraints.copy()]} + + if "and" in base_filter: + and_conditions = base_filter.get("and") + if not isinstance(and_conditions, list): + raise ValueError("Invalid filter format: 'and' must be a list") + return {"and": [*and_conditions, constraints.copy()]} + + if "or" in base_filter: + or_conditions = base_filter.get("or") + if not isinstance(or_conditions, list): + raise ValueError("Invalid filter format: 'or' must be a list") + + merged_or_conditions: list[dict[str, Any]] = [] + for condition in or_conditions: + if not isinstance(condition, dict): + raise ValueError("Invalid filter format: each 'or' condition must be a dict") + merged_condition = condition.copy() + for key, value in 
constraints.items():
+                if key in merged_condition and merged_condition[key] != value:
+                    raise ValueError(
+                        f"Conflicting filter condition for '{key}'. "
+                        "Please merge it manually into request.filter."
+                    )
+                merged_condition[key] = value
+            merged_or_conditions.append(merged_condition)
+
+        return {"or": merged_or_conditions}
+
+    # For plain dict filters, keep strict AND semantics explicitly.
+    return {"and": [base_filter.copy(), constraints.copy()]}
+
+
 def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube):
     """
     Handler for deleting memories. Now unified to delete from text_mem only (includes preferences).
     """
     logger.info(
-        "[Delete memory request] writable_cube_ids: %s, memory_ids: %s, auto_cleanup_working: %s",
+        "[Delete memory request] writable_cube_ids: %s, memory_ids: %s, file_ids: %s, "
+        "has_filter: %s, user_id: %s, session_id: %s, auto_cleanup_working: %s",
         delete_mem_req.writable_cube_ids,
         delete_mem_req.memory_ids,
+        delete_mem_req.file_ids,
+        delete_mem_req.filter is not None,
+        delete_mem_req.user_id,
+        delete_mem_req.session_id,
         getattr(delete_mem_req, "auto_cleanup_working", False),
     )
 
-    # Validate that only one of memory_ids, file_ids, or filter is provided
+    quick_constraints = _build_quick_delete_constraints(delete_mem_req)
+    has_non_empty_filter = bool(delete_mem_req.filter)
+    has_filter_mode = has_non_empty_filter or bool(quick_constraints)
+
+    # Reject empty filter dict when no quick constraints are provided.
+    if delete_mem_req.filter is not None and not has_non_empty_filter and not quick_constraints:
+        return DeleteMemoryResponse(
+            message="filter cannot be empty. Provide a non-empty filter or user_id/session_id.",
+            data={"status": "failure"},
+        )
+
+    # Validate that only one mode is provided: memory_ids, file_ids, or filter-mode.
provided_params = [ delete_mem_req.memory_ids is not None, delete_mem_req.file_ids is not None, - delete_mem_req.filter is not None, + has_filter_mode, ] if sum(provided_params) != 1: return DeleteMemoryResponse( - message="Exactly one of memory_ids, file_ids, or filter must be provided", + message=( + "Exactly one delete mode must be provided: " + "memory_ids, file_ids, or filter/user_id/session_id." + ), data={"status": "failure"}, ) @@ -370,8 +443,14 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: naive_mem_cube.text_mem.delete_by_filter( writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids ) - elif delete_mem_req.filter is not None: - naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter) + elif has_filter_mode: + merged_filter = _merge_delete_filter(delete_mem_req.filter, quick_constraints) + naive_mem_cube.text_mem.delete_by_filter( + writable_cube_ids=delete_mem_req.writable_cube_ids, + filter=merged_filter, + ) + if naive_mem_cube.pref_mem is not None: + naive_mem_cube.pref_mem.delete_by_filter(filter=merged_filter) # After main deletion, optionally clean up related WorkingMemory nodes. 
if working_ids_to_delete: diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py deleted file mode 100644 index ec5cccae1..000000000 --- a/src/memos/api/product_api.py +++ /dev/null @@ -1,38 +0,0 @@ -import logging - -from fastapi import FastAPI - -from memos.api.exceptions import APIExceptionHandler -from memos.api.middleware.request_context import RequestContextMiddleware -from memos.api.routers.product_router import router as product_router - - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - -app = FastAPI( - title="MemOS Product REST APIs", - description="A REST API for managing multiple users with MemOS Product.", - version="1.0.1", -) - -app.add_middleware(RequestContextMiddleware, source="product_api") -# Include routers -app.include_router(product_router) - -# Exception handlers -app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) -app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) - - -if __name__ == "__main__": - import argparse - - import uvicorn - - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8001) - parser.add_argument("--workers", type=int, default=1) - args = parser.parse_args() - uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 59f7f74c8..78dcfc797 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -854,10 +854,22 @@ class GetMemoryDashboardRequest(GetMemoryRequest): class DeleteMemoryRequest(BaseRequest): """Request model for deleting memories.""" - writable_cube_ids: list[str] = Field(None, description="Writable cube IDs") + writable_cube_ids: list[str] | None = Field(None, description="Writable cube IDs") memory_ids: list[str] | None = Field(None, description="Memory 
IDs") file_ids: list[str] | None = Field(None, description="File IDs") filter: dict[str, Any] | None = Field(None, description="Filter for the memory") + user_id: str | None = Field( + None, + description="Quick delete condition: remove memories for this user_id.", + ) + session_id: str | None = Field( + None, + description="Quick delete condition: remove memories for this session_id.", + ) + conversation_id: str | None = Field( + None, + description="Alias of session_id for backward compatibility.", + ) auto_cleanup_working: bool | None = Field( False, description=( @@ -866,6 +878,15 @@ class DeleteMemoryRequest(BaseRequest): ), ) + @model_validator(mode="after") + def normalize_session_alias(self) -> "DeleteMemoryRequest": + """Normalize conversation_id to session_id.""" + if self.conversation_id and self.session_id and self.conversation_id != self.session_id: + raise ValueError("conversation_id and session_id must be the same when both are set") + if self.session_id is None and self.conversation_id is not None: + self.session_id = self.conversation_id + return self + class SuggestionRequest(BaseRequest): """Request model for getting suggestion queries.""" diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py deleted file mode 100644 index 609d61124..000000000 --- a/src/memos/api/routers/product_router.py +++ /dev/null @@ -1,477 +0,0 @@ -import json -import time -import traceback - -from fastapi import APIRouter, HTTPException -from fastapi.responses import StreamingResponse - -from memos.api.config import APIConfig -from memos.api.product_models import ( - BaseResponse, - ChatCompleteRequest, - ChatRequest, - GetMemoryPlaygroundRequest, - MemoryCreateRequest, - MemoryResponse, - SearchRequest, - SearchResponse, - SimpleResponse, - SuggestionRequest, - SuggestionResponse, - UserRegisterRequest, - UserRegisterResponse, -) -from memos.configs.mem_os import MOSConfig -from memos.log import get_logger -from 
memos.mem_os.product import MOSProduct -from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function - - -logger = get_logger(__name__) - -router = APIRouter(prefix="/product", tags=["Product API"]) - -# Initialize MOSProduct instance with lazy initialization -MOS_PRODUCT_INSTANCE = None - - -def get_mos_product_instance(): - """Get or create MOSProduct instance.""" - global MOS_PRODUCT_INSTANCE - if MOS_PRODUCT_INSTANCE is None: - default_config = APIConfig.get_product_default_config() - logger.info(f"*********init_default_mos_config********* {default_config}") - from memos.configs.mem_os import MOSConfig - - mos_config = MOSConfig(**default_config) - - # Get default cube config from APIConfig (may be None if disabled) - default_cube_config = APIConfig.get_default_cube_config() - logger.info(f"*********initdefault_cube_config******** {default_cube_config}") - - # Get DingDing bot functions - dingding_enabled = APIConfig.is_dingding_bot_enabled() - online_bot = get_online_bot_function() if dingding_enabled else None - error_bot = get_error_bot_function() if dingding_enabled else None - - MOS_PRODUCT_INSTANCE = MOSProduct( - default_config=mos_config, - default_cube_config=default_cube_config, - online_bot=online_bot, - error_bot=error_bot, - ) - logger.info("MOSProduct instance created successfully with inheritance architecture") - return MOS_PRODUCT_INSTANCE - - -get_mos_product_instance() - - -@router.post("/configure", summary="Configure MOSProduct", response_model=SimpleResponse) -def set_config(config): - """Set MOSProduct configuration.""" - global MOS_PRODUCT_INSTANCE - MOS_PRODUCT_INSTANCE = MOSProduct(default_config=config) - return SimpleResponse(message="Configuration set successfully") - - -@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse) -def register_user(user_req: UserRegisterRequest): - """Register a new user with configuration and default cube.""" - try: - # 
Get configuration for the user - time_start_register = time.time() - user_config, default_mem_cube = APIConfig.create_user_config( - user_name=user_req.user_id, user_id=user_req.user_id - ) - logger.info(f"user_config: {user_config.model_dump(mode='json')}") - logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}") - logger.info( - f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}" - ) - mos_product = get_mos_product_instance() - - # Register user with default config and mem cube - result = mos_product.user_register( - user_id=user_req.user_id, - user_name=user_req.user_name, - interests=user_req.interests, - config=user_config, - default_mem_cube=default_mem_cube, - mem_cube_id=user_req.mem_cube_id, - ) - logger.info( - f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}" - ) - if result["status"] == "success": - return UserRegisterResponse( - message="User registered successfully", - data={"user_id": result["user_id"], "mem_cube_id": result["default_cube_id"]}, - ) - else: - raise HTTPException(status_code=400, detail=result["message"]) - - except Exception as err: - logger.error(f"Failed to register user: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get( - "/suggestions/{user_id}", summary="Get suggestion queries", response_model=SuggestionResponse -) -def get_suggestion_queries(user_id: str): - """Get suggestion queries for a specific user.""" - try: - mos_product = get_mos_product_instance() - suggestions = mos_product.get_suggestion_query(user_id) - return SuggestionResponse( - message="Suggestions retrieved successfully", data={"query": suggestions} - ) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get 
suggestions: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post( - "/suggestions", - summary="Get suggestion queries with language", - response_model=SuggestionResponse, -) -def get_suggestion_queries_post(suggestion_req: SuggestionRequest): - """Get suggestion queries for a specific user with language preference.""" - try: - mos_product = get_mos_product_instance() - suggestions = mos_product.get_suggestion_query( - user_id=suggestion_req.user_id, - language=suggestion_req.language, - message=suggestion_req.message, - ) - return SuggestionResponse( - message="Suggestions retrieved successfully", data={"query": suggestions} - ) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get suggestions: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetMemoryPlaygroundRequest): - """Get all memories for a specific user.""" - try: - mos_product = get_mos_product_instance() - if memory_req.search_query: - result = mos_product.get_subgraph( - user_id=memory_req.user_id, - query=memory_req.search_query, - mem_cube_ids=memory_req.mem_cube_ids, - ) - return MemoryResponse(message="Memories retrieved successfully", data=result) - else: - result = mos_product.get_all( - user_id=memory_req.user_id, - memory_type=memory_req.memory_type, - mem_cube_ids=memory_req.mem_cube_ids, - ) - return MemoryResponse(message="Memories retrieved successfully", data=result) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get memories: {traceback.format_exc()}") - raise HTTPException(status_code=500, 
detail=str(traceback.format_exc())) from err - - -@router.post("/add", summary="add a new memory", response_model=SimpleResponse) -def create_memory(memory_req: MemoryCreateRequest): - """Create a new memory for a specific user.""" - logger.info("DIAGNOSTIC: /product/add endpoint called. This confirms the new code is deployed.") - # Initialize status_tracker outside try block to avoid NameError in except blocks - status_tracker = None - - try: - time_start_add = time.time() - mos_product = get_mos_product_instance() - - # Track task if task_id is provided - item_id: str | None = None - if ( - memory_req.task_id - and hasattr(mos_product, "mem_scheduler") - and mos_product.mem_scheduler - ): - from uuid import uuid4 - - from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker - - item_id = str(uuid4()) # Generate a unique item_id for this submission - - # Get Redis client from scheduler - if ( - hasattr(mos_product.mem_scheduler, "redis_client") - and mos_product.mem_scheduler.redis_client - ): - status_tracker = TaskStatusTracker(mos_product.mem_scheduler.redis_client) - # Submit task with "product_add" type - status_tracker.task_submitted( - task_id=item_id, # Use generated item_id for internal tracking - user_id=memory_req.user_id, - task_type="product_add", - mem_cube_id=memory_req.mem_cube_id or memory_req.user_id, - business_task_id=memory_req.task_id, # Use memory_req.task_id as business_task_id - ) - status_tracker.task_started(item_id, memory_req.user_id) # Use item_id here - - # Execute the add operation - mos_product.add( - user_id=memory_req.user_id, - memory_content=memory_req.memory_content, - messages=memory_req.messages, - doc_path=memory_req.doc_path, - mem_cube_id=memory_req.mem_cube_id, - source=memory_req.source, - user_profile=memory_req.user_profile, - session_id=memory_req.session_id, - task_id=memory_req.task_id, - ) - - # Mark task as completed - if status_tracker and item_id: - status_tracker.task_completed(item_id, 
memory_req.user_id) - - logger.info( - f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}" - ) - return SimpleResponse(message="Memory created successfully") - - except ValueError as err: - # Mark task as failed if tracking - if status_tracker and item_id: - status_tracker.task_failed(item_id, memory_req.user_id, str(err)) - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - # Mark task as failed if tracking - if status_tracker and item_id: - status_tracker.task_failed(item_id, memory_req.user_id, str(err)) - logger.error(f"Failed to create memory: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/search", summary="Search memories", response_model=SearchResponse) -def search_memories(search_req: SearchRequest): - """Search memories for a specific user.""" - try: - time_start_search = time.time() - mos_product = get_mos_product_instance() - result = mos_product.search( - query=search_req.query, - user_id=search_req.user_id, - install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None, - top_k=search_req.top_k, - session_id=search_req.session_id, - ) - logger.info( - f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}" - ) - return SearchResponse(message="Search completed successfully", data=result) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to search memories: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/chat", summary="Chat with MemOS") -def chat(chat_req: ChatRequest): - """Chat with MemOS for a specific user. 
Returns SSE stream.""" - try: - mos_product = get_mos_product_instance() - - def generate_chat_response(): - """Generate chat response as SSE stream.""" - try: - # Directly yield from the generator without async wrapper - yield from mos_product.chat_with_references( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - session_id=chat_req.session_id, - ) - - except Exception as e: - logger.error(f"Error in chat stream: {e}") - error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" - yield error_data - - return StreamingResponse( - generate_chat_response(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "text/event-stream", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "*", - "Access-Control-Allow-Methods": "*", - }, - ) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") -def chat_complete(chat_req: ChatCompleteRequest): - """Chat with MemOS for a specific user. 
Returns complete response (non-streaming).""" - try: - mos_product = get_mos_product_instance() - - # Collect all responses from the generator - content, references = mos_product.chat( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - base_prompt=chat_req.base_prompt or chat_req.system_prompt, - # will deprecate base_prompt in the future - top_k=chat_req.top_k, - threshold=chat_req.threshold, - session_id=chat_req.session_id, - ) - - # Return the complete response - return { - "message": "Chat completed successfully", - "data": {"response": content, "references": references}, - } - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get("/users", summary="List all users", response_model=BaseResponse[list]) -def list_users(): - """List all registered users.""" - try: - mos_product = get_mos_product_instance() - users = mos_product.list_users() - return BaseResponse(message="Users retrieved successfully", data=users) - except Exception as err: - logger.error(f"Failed to list users: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get("/users/{user_id}", summary="Get user info", response_model=BaseResponse[dict]) -async def get_user_info(user_id: str): - """Get user information including accessible cubes.""" - try: - mos_product = get_mos_product_instance() - user_info = mos_product.get_user_info(user_id) - return BaseResponse(message="User info retrieved successfully", data=user_info) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception 
as err: - logger.error(f"Failed to get user info: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get( - "/configure/{user_id}", summary="Get MOSProduct configuration", response_model=SimpleResponse -) -def get_config(user_id: str): - """Get MOSProduct configuration.""" - global MOS_PRODUCT_INSTANCE - config = MOS_PRODUCT_INSTANCE.default_config - return SimpleResponse(message="Configuration retrieved successfully", data=config) - - -@router.get( - "/users/{user_id}/config", summary="Get user configuration", response_model=BaseResponse[dict] -) -def get_user_config(user_id: str): - """Get user-specific configuration.""" - try: - mos_product = get_mos_product_instance() - config = mos_product.get_user_config(user_id) - if config: - return BaseResponse( - message="User configuration retrieved successfully", - data=config.model_dump(mode="json"), - ) - else: - raise HTTPException( - status_code=404, detail=f"Configuration not found for user {user_id}" - ) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get user config: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.put( - "/users/{user_id}/config", summary="Update user configuration", response_model=SimpleResponse -) -def update_user_config(user_id: str, config_data: dict): - """Update user-specific configuration.""" - try: - mos_product = get_mos_product_instance() - - # Create MOSConfig from the provided data - config = MOSConfig(**config_data) - - # Update the configuration - success = mos_product.update_user_config(user_id, config) - if success: - return SimpleResponse(message="User configuration updated successfully") - else: - raise HTTPException(status_code=500, detail="Failed to update user configuration") - - except ValueError as err: - raise 
HTTPException(status_code=400, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to update user config: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get( - "/instances/status", summary="Get user configuration status", response_model=BaseResponse[dict] -) -def get_instance_status(): - """Get information about active user configurations in memory.""" - try: - mos_product = get_mos_product_instance() - status_info = mos_product.get_user_instance_info() - return BaseResponse( - message="User configuration status retrieved successfully", data=status_info - ) - except Exception as err: - logger.error(f"Failed to get user configuration status: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get("/instances/count", summary="Get active user count", response_model=BaseResponse[int]) -def get_active_user_count(): - """Get the number of active user configurations in memory.""" - try: - mos_product = get_mos_product_instance() - count = mos_product.get_active_user_count() - return BaseResponse(message="Active user count retrieved successfully", data=count) - except Exception as err: - logger.error(f"Failed to get active user count: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 529a709a4..1f1e6ccde 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -29,6 +29,17 @@ # Include routers app.include_router(server_router) + +@app.get("/health") +def health_check(): + """Container and load balancer health endpoint.""" + return { + "status": "healthy", + "service": "memos", + "version": app.version, + } + + # Request validation failed app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) 
# Invalid business code parameters diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py deleted file mode 100644 index 24a36f017..000000000 --- a/src/memos/api/start_api.py +++ /dev/null @@ -1,433 +0,0 @@ -import logging -import os - -from typing import Any, Generic, TypeVar - -from dotenv import load_dotenv -from fastapi import FastAPI -from fastapi.requests import Request -from fastapi.responses import JSONResponse, RedirectResponse -from pydantic import BaseModel, Field - -from memos.api.middleware.request_context import RequestContextMiddleware -from memos.configs.mem_os import MOSConfig -from memos.mem_os.main import MOS -from memos.mem_user.user_manager import UserManager, UserRole - - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - -# Load environment variables -load_dotenv(override=True) - -T = TypeVar("T") - -# Default configuration -DEFAULT_CONFIG = { - "user_id": os.getenv("MOS_USER_ID", "default_user"), - "session_id": os.getenv("MOS_SESSION_ID", "default_session"), - "enable_textual_memory": True, - "enable_activation_memory": False, - "top_k": int(os.getenv("MOS_TOP_K", "5")), - "chat_model": { - "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"), - "config": { - "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-3.5-turbo"), - "api_key": os.getenv("OPENAI_API_KEY", "apikey"), - "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.7")), - "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), - }, - }, -} - -# Initialize MOS instance with lazy initialization -MOS_INSTANCE = None - - -def get_mos_instance(): - """Get or create MOS instance with default user creation.""" - global MOS_INSTANCE - if MOS_INSTANCE is None: - # Create a temporary MOS instance to access user manager - temp_config = MOSConfig(**DEFAULT_CONFIG) - temp_mos = MOS.__new__(MOS) - temp_mos.config = temp_config - temp_mos.user_id = 
temp_config.user_id - temp_mos.session_id = temp_config.session_id - temp_mos.mem_cubes = {} - temp_mos.chat_llm = None # Will be initialized later - temp_mos.user_manager = UserManager() - - # Create default user if it doesn't exist - if not temp_mos.user_manager.validate_user(temp_config.user_id): - temp_mos.user_manager.create_user( - user_name=temp_config.user_id, role=UserRole.USER, user_id=temp_config.user_id - ) - logger.info(f"Created default user: {temp_config.user_id}") - - # Now create the actual MOS instance - MOS_INSTANCE = MOS(config=temp_config) - - return MOS_INSTANCE - - -app = FastAPI( - title="MemOS REST APIs", - description="A REST API for managing and searching memories using MemOS.", - version="1.0.0", -) - -app.add_middleware(RequestContextMiddleware) - - -class BaseRequest(BaseModel): - """Base model for all requests.""" - - user_id: str | None = Field( - None, description="User ID for the request", json_schema_extra={"example": "user123"} - ) - - -class BaseResponse(BaseModel, Generic[T]): - """Base model for all responses.""" - - code: int = Field(200, description="Response status code", json_schema_extra={"example": 200}) - message: str = Field( - ..., description="Response message", json_schema_extra={"example": "Operation successful"} - ) - data: T | None = Field(None, description="Response data") - - -class Message(BaseModel): - role: str = Field( - ..., - description="Role of the message (user or assistant).", - json_schema_extra={"example": "user"}, - ) - content: str = Field( - ..., - description="Message content.", - json_schema_extra={"example": "Hello, how can I help you?"}, - ) - - -class MemoryCreate(BaseRequest): - messages: list[Message] | None = Field( - None, - description="List of messages to store.", - json_schema_extra={"example": [{"role": "user", "content": "Hello"}]}, - ) - mem_cube_id: str | None = Field( - None, description="ID of the memory cube", json_schema_extra={"example": "cube123"} - ) - memory_content: str | 
None = Field( - None, - description="Content to store as memory", - json_schema_extra={"example": "This is a memory content"}, - ) - doc_path: str | None = Field( - None, - description="Path to document to store", - json_schema_extra={"example": "/path/to/document.txt"}, - ) - - -class SearchRequest(BaseRequest): - query: str = Field( - ..., - description="Search query.", - json_schema_extra={"example": "How to implement a feature?"}, - ) - install_cube_ids: list[str] | None = Field( - None, - description="List of cube IDs to search in", - json_schema_extra={"example": ["cube123", "cube456"]}, - ) - - -class MemCubeRegister(BaseRequest): - mem_cube_name_or_path: str = Field( - ..., - description="Name or path of the MemCube to register.", - json_schema_extra={"example": "/path/to/cube"}, - ) - mem_cube_id: str | None = Field( - None, description="ID for the MemCube", json_schema_extra={"example": "cube123"} - ) - - -class ChatRequest(BaseRequest): - query: str = Field( - ..., - description="Chat query message.", - json_schema_extra={"example": "What is the latest update?"}, - ) - - -class UserCreate(BaseRequest): - user_name: str | None = Field( - None, description="Name of the user", json_schema_extra={"example": "john_doe"} - ) - role: str = Field("user", description="Role of the user", json_schema_extra={"example": "user"}) - user_id: str = Field(..., description="User ID", json_schema_extra={"example": "user123"}) - - -class CubeShare(BaseRequest): - target_user_id: str = Field( - ..., description="Target user ID to share with", json_schema_extra={"example": "user456"} - ) - - -class SimpleResponse(BaseResponse[None]): - """Simple response model for operations without data return.""" - - -class ConfigResponse(BaseResponse[None]): - """Response model for configuration endpoint.""" - - -class MemoryResponse(BaseResponse[dict]): - """Response model for memory operations.""" - - -class SearchResponse(BaseResponse[dict]): - """Response model for search 
operations.""" - - -class ChatResponse(BaseResponse[str]): - """Response model for chat operations.""" - - -class UserResponse(BaseResponse[dict]): - """Response model for user operations.""" - - -class UserListResponse(BaseResponse[list]): - """Response model for user list operations.""" - - -@app.post("/configure", summary="Configure MemOS", response_model=ConfigResponse) -async def set_config(config: MOSConfig): - """Set MemOS configuration.""" - global MOS_INSTANCE - - # Create a temporary user manager to check/create default user - temp_user_manager = UserManager() - - # Create default user if it doesn't exist - if not temp_user_manager.validate_user(config.user_id): - temp_user_manager.create_user( - user_name=config.user_id, role=UserRole.USER, user_id=config.user_id - ) - logger.info(f"Created default user: {config.user_id}") - - # Now create the MOS instance - MOS_INSTANCE = MOS(config=config) - return ConfigResponse(message="Configuration set successfully") - - -@app.post("/users", summary="Create a new user", response_model=UserResponse) -async def create_user(user_create: UserCreate): - """Create a new user.""" - mos_instance = get_mos_instance() - role = UserRole(user_create.role) - user_id = mos_instance.create_user( - user_id=user_create.user_id, role=role, user_name=user_create.user_name - ) - return UserResponse(message="User created successfully", data={"user_id": user_id}) - - -@app.get("/users", summary="List all users", response_model=UserListResponse) -async def list_users(): - """List all active users.""" - mos_instance = get_mos_instance() - users = mos_instance.list_users() - return UserListResponse(message="Users retrieved successfully", data=users) - - -@app.get("/users/me", summary="Get current user info", response_model=UserResponse) -async def get_user_info(): - """Get current user information including accessible cubes.""" - mos_instance = get_mos_instance() - user_info = mos_instance.get_user_info() - return 
UserResponse(message="User info retrieved successfully", data=user_info) - - -@app.post("/mem_cubes", summary="Register a MemCube", response_model=SimpleResponse) -async def register_mem_cube(mem_cube: MemCubeRegister): - """Register a new MemCube.""" - mos_instance = get_mos_instance() - mos_instance.register_mem_cube( - mem_cube_name_or_path=mem_cube.mem_cube_name_or_path, - mem_cube_id=mem_cube.mem_cube_id, - user_id=mem_cube.user_id, - ) - return SimpleResponse(message="MemCube registered successfully") - - -@app.delete( - "/mem_cubes/{mem_cube_id}", summary="Unregister a MemCube", response_model=SimpleResponse -) -async def unregister_mem_cube(mem_cube_id: str, user_id: str | None = None): - """Unregister a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.unregister_mem_cube(mem_cube_id=mem_cube_id, user_id=user_id) - return SimpleResponse(message="MemCube unregistered successfully") - - -@app.post( - "/mem_cubes/{cube_id}/share", - summary="Share a cube with another user", - response_model=SimpleResponse, -) -async def share_cube(cube_id: str, share_request: CubeShare): - """Share a cube with another user.""" - mos_instance = get_mos_instance() - success = mos_instance.share_cube_with_user(cube_id, share_request.target_user_id) - if success: - return SimpleResponse(message="Cube shared successfully") - else: - raise ValueError("Failed to share cube") - - -@app.post("/memories", summary="Create memories", response_model=SimpleResponse) -async def add_memory(memory_create: MemoryCreate): - """Store new memories in a MemCube.""" - if not any([memory_create.messages, memory_create.memory_content, memory_create.doc_path]): - raise ValueError("Either messages, memory_content, or doc_path must be provided") - mos_instance = get_mos_instance() - if memory_create.messages: - messages = [m.model_dump() for m in memory_create.messages] - mos_instance.add( - messages=messages, - mem_cube_id=memory_create.mem_cube_id, - user_id=memory_create.user_id, - ) - 
elif memory_create.memory_content: - mos_instance.add( - memory_content=memory_create.memory_content, - mem_cube_id=memory_create.mem_cube_id, - user_id=memory_create.user_id, - ) - elif memory_create.doc_path: - mos_instance.add( - doc_path=memory_create.doc_path, - mem_cube_id=memory_create.mem_cube_id, - user_id=memory_create.user_id, - ) - return SimpleResponse(message="Memories added successfully") - - -@app.get("/memories", summary="Get all memories", response_model=MemoryResponse) -async def get_all_memories( - mem_cube_id: str | None = None, - user_id: str | None = None, -): - """Retrieve all memories from a MemCube.""" - mos_instance = get_mos_instance() - result = mos_instance.get_all(mem_cube_id=mem_cube_id, user_id=user_id) - return MemoryResponse(message="Memories retrieved successfully", data=result) - - -@app.get( - "/memories/{mem_cube_id}/{memory_id}", summary="Get a memory", response_model=MemoryResponse -) -async def get_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None): - """Retrieve a specific memory by ID from a MemCube.""" - mos_instance = get_mos_instance() - result = mos_instance.get(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id) - return MemoryResponse(message="Memory retrieved successfully", data=result) - - -@app.post("/search", summary="Search memories", response_model=SearchResponse) -async def search_memories(search_req: SearchRequest): - """Search for memories across MemCubes.""" - mos_instance = get_mos_instance() - result = mos_instance.search( - query=search_req.query, - user_id=search_req.user_id, - install_cube_ids=search_req.install_cube_ids, - ) - return SearchResponse(message="Search completed successfully", data=result) - - -@app.put( - "/memories/{mem_cube_id}/{memory_id}", summary="Update a memory", response_model=SimpleResponse -) -async def update_memory( - mem_cube_id: str, memory_id: str, updated_memory: dict[str, Any], user_id: str | None = None -): - """Update an existing memory in 
a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.update( - mem_cube_id=mem_cube_id, - memory_id=memory_id, - text_memory_item=updated_memory, - user_id=user_id, - ) - return SimpleResponse(message="Memory updated successfully") - - -@app.delete( - "/memories/{mem_cube_id}/{memory_id}", summary="Delete a memory", response_model=SimpleResponse -) -async def delete_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None): - """Delete a specific memory from a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.delete(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id) - return SimpleResponse(message="Memory deleted successfully") - - -@app.delete("/memories/{mem_cube_id}", summary="Delete all memories", response_model=SimpleResponse) -async def delete_all_memories(mem_cube_id: str, user_id: str | None = None): - """Delete all memories from a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.delete_all(mem_cube_id=mem_cube_id, user_id=user_id) - return SimpleResponse(message="All memories deleted successfully") - - -@app.post("/chat", summary="Chat with MemOS", response_model=ChatResponse) -async def chat(chat_req: ChatRequest): - """Chat with the MemOS system.""" - mos_instance = get_mos_instance() - response = mos_instance.chat(query=chat_req.query, user_id=chat_req.user_id) - if response is None: - raise ValueError("No response generated") - return ChatResponse(message="Chat response generated", data=response) - - -@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False) -async def home(): - """Redirect to the OpenAPI documentation.""" - return RedirectResponse(url="/docs", status_code=307) - - -@app.exception_handler(ValueError) -async def value_error_handler(request: Request, exc: ValueError): - """Handle ValueError exceptions globally.""" - return JSONResponse( - status_code=400, - content={"code": 400, "message": str(exc), "data": None}, - ) - - 
-@app.exception_handler(Exception) -async def global_exception_handler(request: Request, exc: Exception): - """Handle all unhandled exceptions globally.""" - logger.exception("Unhandled error:") - return JSONResponse( - status_code=500, - content={"code": 500, "message": str(exc), "data": None}, - ) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on") - parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development") - args = parser.parse_args() diff --git a/src/memos/cli.py b/src/memos/cli.py index 092f2d276..2ead5ab29 100644 --- a/src/memos/cli.py +++ b/src/memos/cli.py @@ -11,9 +11,16 @@ from io import BytesIO +def get_openapi_app(): + """Return the FastAPI app used for OpenAPI export.""" + from memos.api.server_api import app + + return app + + def export_openapi(output: str) -> bool: """Export OpenAPI schema to JSON file.""" - from memos.api.server_api import app + app = get_openapi_app() # Create directory if it doesn't exist if os.path.dirname(output): diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 5900d2357..98de09812 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -103,57 +103,6 @@ def validate_community(self): return self -class NebulaGraphDBConfig(BaseGraphDBConfig): - """ - NebulaGraph-specific configuration. - - Key concepts: - - `space`: Equivalent to a database or namespace. All tag/edge/schema live within a space. - - `user_name`: Used for logical tenant isolation if needed. - - `auto_create`: Whether to automatically create the target space if it does not exist. 
- - Example: - --- - hosts = ["127.0.0.1:9669"] - user = "root" - password = "nebula" - space = "shared_graph" - user_name = "alice" - """ - - space: str = Field( - ..., description="The name of the target NebulaGraph space (like a database)" - ) - user_name: str | None = Field( - default=None, - description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)", - ) - auto_create: bool = Field( - default=False, - description="Whether to auto-create the space if it does not exist", - ) - use_multi_db: bool = Field( - default=True, - description=( - "If True: use Neo4j's multi-database feature for physical isolation; " - "each user typically gets a separate database. " - "If False: use a single shared database with logical isolation by user_name." - ), - ) - max_client: int = Field( - default=1000, - description=("max_client"), - ) - embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding") - - @model_validator(mode="after") - def validate_config(self): - """Validate config.""" - if not self.space: - raise ValueError("`space` must be provided") - return self - - class PolarDBGraphDBConfig(BaseConfig): """ PolarDB-specific configuration. 
@@ -299,7 +248,6 @@ class GraphDBConfigFactory(BaseModel): backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDBConfig, "neo4j-community": Neo4jCommunityGraphDBConfig, - "nebular": NebulaGraphDBConfig, "polardb": PolarDBGraphDBConfig, "postgres": PostgresGraphDBConfig, } diff --git a/src/memos/configs/internet_retriever.py b/src/memos/configs/internet_retriever.py index 1c5e2b8ad..562cfdd1f 100644 --- a/src/memos/configs/internet_retriever.py +++ b/src/memos/configs/internet_retriever.py @@ -67,6 +67,19 @@ class BochaSearchConfig(BaseInternetRetrieverConfig): ) +class TavilySearchConfig(BaseInternetRetrieverConfig): + """Configuration class for Tavily Search API.""" + + search_engine_id: str | None = Field( + None, description="Not used for Tavily Search (kept for compatibility)" + ) + max_results: int = Field(default=10, description="Maximum number of results to retrieve") + search_depth: str = Field(default="basic", description="Search depth: 'basic' or 'advanced'") + include_answer: bool = Field( + default=False, description="Whether to include an AI-generated answer" + ) + + class InternetRetrieverConfigFactory(BaseConfig): """Factory class for creating internet retriever configurations.""" @@ -82,6 +95,7 @@ class InternetRetrieverConfigFactory(BaseConfig): "bing": BingSearchConfig, "xinyu": XinyuSearchConfig, "bocha": BochaSearchConfig, + "tavily": TavilySearchConfig, } @field_validator("backend") diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index 11c39b33c..81f7038fa 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -72,6 +72,15 @@ class DeepSeekLLMConfig(OpenAILLMConfig): ) +class MinimaxLLMConfig(OpenAILLMConfig): + api_key: str = Field(..., description="API key for MiniMax") + api_base: str = Field( + default="https://api.minimax.io/v1", + description="Base URL for MiniMax OpenAI-compatible API", + ) + extra_body: Any = Field(default=None, description="Extra options for API") + + class 
AzureLLMConfig(BaseLLMConfig): base_url: str = Field( default="https://api.openai.azure.com/", @@ -146,6 +155,7 @@ class LLMConfigFactory(BaseConfig): "huggingface_singleton": HFLLMConfig, # Add singleton support "qwen": QwenLLMConfig, "deepseek": DeepSeekLLMConfig, + "minimax": MinimaxLLMConfig, "openai_new": OpenAIResponsesLLMConfig, } diff --git a/src/memos/context/context.py b/src/memos/context/context.py index 5c8401732..5347de880 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -155,7 +155,7 @@ def get_current_user_name() -> str | None: def get_current_source() -> str | None: """ - Get the current request's source (e.g., 'product_api' or 'server_api'). + Get the current request's source (for example, 'server_api'). """ context = _request_context.get() if context: diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index c207e3190..93b5971ec 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -2,7 +2,6 @@ from memos.configs.graph_db import GraphDBConfigFactory from memos.graph_dbs.base import BaseGraphDB -from memos.graph_dbs.nebular import NebulaGraphDB from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB @@ -15,7 +14,6 @@ class GraphStoreFactory(BaseGraphDB): backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDB, "neo4j-community": Neo4jCommunityGraphDB, - "nebular": NebulaGraphDB, "polardb": PolarDBGraphDB, "postgres": PostgresGraphDB, } diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py deleted file mode 100644 index 428d6d09e..000000000 --- a/src/memos/graph_dbs/nebular.py +++ /dev/null @@ -1,1794 +0,0 @@ -import json -import traceback - -from contextlib import suppress -from datetime import datetime -from threading import Lock -from typing import TYPE_CHECKING, Any, ClassVar, Literal - -import numpy as np - -from 
memos.configs.graph_db import NebulaGraphDBConfig -from memos.dependency import require_python_package -from memos.graph_dbs.base import BaseGraphDB -from memos.log import get_logger -from memos.utils import timed - - -if TYPE_CHECKING: - from nebulagraph_python import ( - NebulaClient, - ) - - -logger = get_logger(__name__) - - -_TRANSIENT_ERR_KEYS = ( - "Session not found", - "Connection not established", - "timeout", - "deadline exceeded", - "Broken pipe", - "EOFError", - "socket closed", - "connection reset", - "connection refused", -) - - -@timed -def _normalize(vec: list[float]) -> list[float]: - v = np.asarray(vec, dtype=np.float32) - norm = np.linalg.norm(v) - return (v / (norm if norm else 1.0)).tolist() - - -@timed -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - -@timed -def _escape_str(value: str) -> str: - out = [] - for ch in value: - code = ord(ch) - if ch == "\\": - out.append("\\\\") - elif ch == '"': - out.append('\\"') - elif ch == "\n": - out.append("\\n") - elif ch == "\r": - out.append("\\r") - elif ch == "\t": - out.append("\\t") - elif ch == "\b": - out.append("\\b") - elif ch == "\f": - out.append("\\f") - elif code < 0x20 or code in (0x2028, 0x2029): - out.append(f"\\u{code:04x}") - else: - out.append(ch) - return "".join(out) - - -@timed -def _format_datetime(value: str | datetime) -> str: - """Ensure datetime is in ISO 8601 format string.""" - if isinstance(value, datetime): - return value.isoformat() - return str(value) - - -@timed -def _normalize_datetime(val): - """ - Normalize datetime to ISO 8601 UTC string with +00:00. 
- - If val is datetime object -> keep isoformat() (Neo4j) - - If val is string without timezone -> append +00:00 (Nebula) - - Otherwise just str() - """ - if hasattr(val, "isoformat"): - return val.isoformat() - if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")): - return val + "+08:00" - return str(val) - - -class NebulaGraphDB(BaseGraphDB): - """ - NebulaGraph-based implementation of a graph memory store. - """ - - # ====== shared pool cache & refcount ====== - # These are process-local; in a multi-process model each process will - # have its own cache. - _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {} - _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {} - _CLIENT_LOCK: ClassVar[Lock] = Lock() - _CLIENT_INIT_DONE: ClassVar[set[str]] = set() - - @staticmethod - def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]: - hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None) - if isinstance(hosts, str): - return [hosts] - return list(hosts or []) - - @staticmethod - def _make_client_key(cfg: NebulaGraphDBConfig) -> str: - hosts = NebulaGraphDB._get_hosts_from_cfg(cfg) - return "|".join( - [ - "nebula-sync", - ",".join(hosts), - str(getattr(cfg, "user", "")), - str(getattr(cfg, "space", "")), - ] - ) - - @classmethod - def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB": - tmp = object.__new__(NebulaGraphDB) - tmp.config = cfg - tmp.db_name = cfg.space - tmp.user_name = None - tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) - tmp.default_memory_dimension = 3072 - tmp.common_fields = { - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - } - tmp.base_fields = set(tmp.common_fields) - {"usage"} - tmp.heavy_fields = {"usage"} - tmp.dim_field = ( - 
f"embedding_{tmp.embedding_dimension}" - if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) - else "embedding" - ) - tmp.system_db_name = cfg.space - tmp._client = client - tmp._owns_client = False - return tmp - - @classmethod - def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]: - from nebulagraph_python import ( - ConnectionConfig, - NebulaClient, - SessionConfig, - SessionPoolConfig, - ) - - key = cls._make_client_key(cfg) - with cls._CLIENT_LOCK: - client = cls._CLIENT_CACHE.get(key) - if client is None: - # Connection setting - - tmp_client = NebulaClient( - hosts=cfg.uri, - username=cfg.user, - password=cfg.password, - session_config=SessionConfig(graph=None), - session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000), - ) - try: - cls._ensure_space_exists(tmp_client, cfg) - finally: - tmp_client.close() - - conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) - if conn_conf is None: - conn_conf = ConnectionConfig.from_defults( - cls._get_hosts_from_cfg(cfg), - getattr(cfg, "ssl_param", None), - ) - - sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) - pool_conf = SessionPoolConfig( - size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000 - ) - - client = NebulaClient( - hosts=conn_conf.hosts, - username=cfg.user, - password=cfg.password, - conn_config=conn_conf, - session_config=sess_conf, - session_pool_config=pool_conf, - ) - cls._CLIENT_CACHE[key] = client - cls._CLIENT_REFCOUNT[key] = 0 - logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}") - - cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1 - - if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE: - try: - pass - finally: - pass - - if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE: - with cls._CLIENT_LOCK: - if key not in cls._CLIENT_INIT_DONE: - admin = cls._bootstrap_admin(cfg, client) - try: - 
admin._ensure_database_exists() - admin._create_basic_property_indexes() - admin._create_vector_index( - dimensions=int( - admin.embedding_dimension or admin.default_memory_dimension - ), - ) - cls._CLIENT_INIT_DONE.add(key) - logger.info("[NebulaGraphDBSync] One-time init done") - except Exception: - logger.exception("[NebulaGraphDBSync] One-time init failed") - - return key, client - - def _refresh_client(self): - """ - refresh NebulaClient: - """ - old_key = getattr(self, "_client_key", None) - if not old_key: - return - - cls = self.__class__ - with cls._CLIENT_LOCK: - try: - if old_key in cls._CLIENT_CACHE: - try: - cls._CLIENT_CACHE[old_key].close() - except Exception as e: - logger.warning(f"[refresh_client] close old client error: {e}") - finally: - cls._CLIENT_CACHE.pop(old_key, None) - finally: - cls._CLIENT_REFCOUNT[old_key] = 0 - - new_key, new_client = cls._get_or_create_shared_client(self.config) - self._client_key = new_key - self._client = new_client - logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}") - - @classmethod - def _release_shared_client(cls, key: str): - with cls._CLIENT_LOCK: - if key not in cls._CLIENT_CACHE: - return - cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1) - if cls._CLIENT_REFCOUNT[key] == 0: - try: - cls._CLIENT_CACHE[key].close() - except Exception as e: - logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}") - finally: - cls._CLIENT_CACHE.pop(key, None) - cls._CLIENT_REFCOUNT.pop(key, None) - logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}") - - @classmethod - def close_all_shared_clients(cls): - with cls._CLIENT_LOCK: - for key, client in list(cls._CLIENT_CACHE.items()): - try: - client.close() - except Exception as e: - logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}") - finally: - logger.info(f"[NebulaGraphDBSync] Closed client key={key}") - cls._CLIENT_CACHE.clear() - cls._CLIENT_REFCOUNT.clear() - - 
@require_python_package( - import_name="nebulagraph_python", - install_command="pip install nebulagraph-python>=5.1.1", - install_link=".....", - ) - def __init__(self, config: NebulaGraphDBConfig): - """ - NebulaGraph DB client initialization. - - Required config attributes: - - hosts: list[str] like ["host1:port", "host2:port"] - - user: str - - password: str - - db_name: str (optional for basic commands) - - Example config: - { - "hosts": ["xxx.xx.xx.xxx:xxxx"], - "user": "root", - "password": "nebula", - "space": "test" - } - """ - - assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED" - self.config = config - self.db_name = config.space - self.user_name = config.user_name - self.embedding_dimension = config.embedding_dimension - self.default_memory_dimension = 3072 - self.common_fields = { - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - } - self.base_fields = set(self.common_fields) - {"usage"} - self.heavy_fields = {"usage"} - self.dim_field = ( - f"embedding_{self.embedding_dimension}" - if (str(self.embedding_dimension) != str(self.default_memory_dimension)) - else "embedding" - ) - self.system_db_name = config.space - - # ---- NEW: pool acquisition strategy - # Get or create a shared pool from the class-level cache - self._client_key, self._client = self._get_or_create_shared_client(config) - self._owns_client = True - - logger.info("Connected to NebulaGraph successfully.") - - @timed - def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True): - def _wrap_use_db(q: str) -> str: - if auto_set_db and self.db_name: - return f"USE `{self.db_name}`\n{q}" - return q - - try: - return self._client.execute(_wrap_use_db(gql), timeout=timeout) - - except Exception as e: - emsg = str(e) - if any(k.lower() in emsg.lower() for k in 
_TRANSIENT_ERR_KEYS): - logger.warning(f"[execute_query] {e!s} โ†’ refreshing session pool and retry once...") - try: - self._refresh_client() - return self._client.execute(_wrap_use_db(gql), timeout=timeout) - except Exception: - logger.exception("[execute_query] retry after refresh failed") - raise - raise - - @timed - def close(self): - """ - Close the connection resource if this instance owns it. - - - If pool was injected (`shared_pool`), do nothing. - - If pool was acquired via shared cache, decrement refcount and close - when the last owner releases it. - """ - if not self._owns_client: - logger.debug("[NebulaGraphDBSync] close() skipped (injected client).") - return - if self._client_key: - self._release_shared_client(self._client_key) - self._client_key = None - self._client = None - - # NOTE: __del__ is best-effort; do not rely on GC order. - def __del__(self): - with suppress(Exception): - self.close() - - @timed - def create_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 3072, - index_name: str = "memory_vector_index", - ) -> None: - # Create vector index - self._create_vector_index(label, vector_property, dimensions, index_name) - # Create indexes - self._create_basic_property_indexes() - - @timed - def remove_oldest_memory( - self, memory_type: str, keep_latest: int, user_name: str | None = None - ) -> None: - """ - Remove all WorkingMemory nodes except the latest `keep_latest` entries. - - Args: - memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). - keep_latest (int): Number of latest WorkingMemory entries to keep. - user_name(str): optional user_name. 
- """ - try: - user_name = user_name if user_name else self.config.user_name - optional_condition = f"AND n.user_name = '{user_name}'" - count = self.count_nodes(memory_type, user_name) - if count > keep_latest: - delete_query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {int(keep_latest)} - DETACH DELETE n - """ - self.execute_query(delete_query) - except Exception as e: - logger.warning(f"Delete old mem error: {e}") - - @timed - def add_node( - self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None - ) -> None: - """ - Insert or update a Memory node in NebulaGraph. - """ - metadata["user_name"] = user_name if user_name else self.config.user_name - now = datetime.utcnow() - metadata = metadata.copy() - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - metadata["node_type"] = metadata.pop("type") - metadata["id"] = id - metadata["memory"] = memory - - if "embedding" in metadata and isinstance(metadata["embedding"], list): - assert len(metadata["embedding"]) == self.embedding_dimension, ( - f"input embedding dimension must equal to {self.embedding_dimension}" - ) - embedding = metadata.pop("embedding") - metadata[self.dim_field] = _normalize(embedding) - - metadata = self._metadata_filter(metadata) - properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()) - gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})" - - try: - self.execute_query(gql) - logger.info("insert success") - except Exception as e: - logger.error( - f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}" - ) - - @timed - def node_not_exist(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' - query = f""" - MATCH (n@Memory 
/*+ INDEX(idx_memory_user_name) */) - WHERE {filter_clause} - RETURN n.id AS id - LIMIT 1 - """ - - try: - result = self.execute_query(query) - return result.size == 0 - except Exception as e: - logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) - raise - - @timed - def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: - """ - Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. - """ - user_name = user_name if user_name else self.config.user_name - fields = fields.copy() - set_clauses = [] - for k, v in fields.items(): - set_clauses.append(f"n.{k} = {self._format_value(v, k)}") - - set_clause_str = ",\n ".join(set_clauses) - - query = f""" - MATCH (n@Memory {{id: "{id}"}}) - """ - query += f'WHERE n.user_name = "{user_name}"' - - query += f"\nSET {set_clause_str}" - self.execute_query(query) - - @timed - def delete_node(self, id: str, user_name: str | None = None) -> None: - """ - Delete a node from the graph. - Args: - id: Node identifier to delete. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} - DETACH DELETE n - """ - self.execute_query(query) - - @timed - def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None): - """ - Create an edge from source node to target node. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). 
- user_name (str, optional): User name for filtering in non-multi-db mode - """ - if not source_id or not target_id: - raise ValueError("[add_edge] source_id and target_id must be provided") - user_name = user_name if user_name else self.config.user_name - props = "" - props = f'{{user_name: "{user_name}"}}' - insert_stmt = f''' - MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) - INSERT (a) -[e@{type} {props}]-> (b) - ''' - try: - self.execute_query(insert_stmt) - except Exception as e: - logger.error(f"Failed to insert edge: {e}", exc_info=True) - - @timed - def delete_edge( - self, source_id: str, target_id: str, type: str, user_name: str | None = None - ) -> None: - """ - Delete a specific edge between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type to remove. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (a@Memory) -[r@{type}]-> (b@Memory) - WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} - """ - - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - query += "\nDELETE r" - self.execute_query(query) - - @timed - def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{memory_type}" - """ - query += f"\nAND n.user_name = '{user_name}'" - query += "\nRETURN COUNT(n) AS count" - - try: - result = self.execute_query(query) - return result.one_or_none()["count"].value - except Exception as e: - logger.error(f"[get_memory_count] Failed: {e}") - return -1 - - @timed - def count_nodes(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - query 
= f""" - MATCH (n@Memory) - WHERE n.memory_type = "{scope}" - """ - query += f"\nAND n.user_name = '{user_name}'" - query += "\nRETURN count(n) AS count" - - result = self.execute_query(query) - return result.one_or_none()["count"].value - - @timed - def edge_exists( - self, - source_id: str, - target_id: str, - type: str = "ANY", - direction: str = "OUTGOING", - user_name: str | None = None, - ) -> bool: - """ - Check if an edge exists between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type. Use "ANY" to match any relationship type. - direction: Direction of the edge. - Use "OUTGOING" (default), "INCOMING", or "ANY". - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - True if the edge exists, otherwise False. - """ - # Prepare the relationship pattern - user_name = user_name if user_name else self.config.user_name - rel = "r" if type == "ANY" else f"r@{type}" - - # Prepare the match pattern with direction - if direction == "OUTGOING": - pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})" - elif direction == "INCOMING": - pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})" - elif direction == "ANY": - pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})" - else: - raise ValueError( - f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." 
- ) - query = f"MATCH {pattern}" - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - query += "\nRETURN r" - - # Run the Cypher query - result = self.execute_query(query) - record = result.one_or_none() - if record is None: - return False - return record.values() is not None - - @timed - # Graph Query & Reasoning - def get_node( - self, id: str, include_embedding: bool = False, user_name: str | None = None - ) -> dict[str, Any] | None: - """ - Retrieve a Memory node by its unique ID. - - Args: - id (str): Node ID (Memory.id) - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - dict: Node properties as key-value pairs, or None if not found. - """ - filter_clause = f'n.id = "{id}"' - return_fields = self._build_return_fields(include_embedding) - gql = f""" - MATCH (n@Memory) - WHERE {filter_clause} - RETURN {return_fields} - """ - - try: - result = self.execute_query(gql) - for row in result: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - return node - - except Exception as e: - logger.error( - f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}" - ) - return None - - @timed - def get_nodes( - self, - ids: list[str], - include_embedding: bool = False, - user_name: str | None = None, - **kwargs, - ) -> list[dict[str, Any]]: - """ - Retrieve the metadata and memory of a list of nodes. - Args: - ids: List of Node identifier. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. - - Notes: - - Assumes all provided IDs are valid and exist. - - Returns empty list if input is empty. 
- """ - if not ids: - return [] - # Safe formatting of the ID list - id_list = ",".join(f'"{_id}"' for _id in ids) - - return_fields = self._build_return_fields(include_embedding) - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.id IN [{id_list}] - RETURN {return_fields} - """ - nodes = [] - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - nodes.append(self._parse_node(props)) - except Exception as e: - logger.error( - f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}" - ) - return nodes - - @timed - def get_edges( - self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None - ) -> list[dict[str, str]]: - """ - Get edges connected to a node, with optional type and direction filter. - - Args: - id: Node ID to retrieve edges for. - type: Relationship type to match, or 'ANY' to match all. - direction: 'OUTGOING', 'INCOMING', or 'ANY'. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of edges: - [ - {"from": "source_id", "to": "target_id", "type": "RELATE"}, - ... - ] - """ - # Build relationship type filter - rel_type = "" if type == "ANY" else f"@{type}" - user_name = user_name if user_name else self.config.user_name - # Build Cypher pattern based on direction - if direction == "OUTGOING": - pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "INCOMING": - pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "ANY": - pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)" - where_clause = f"a.id = '{id}' OR b.id = '{id}'" - else: - raise ValueError("Invalid direction. 
Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - - where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - - query = f""" - MATCH {pattern} - WHERE {where_clause} - RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type - """ - - result = self.execute_query(query) - edges = [] - for record in result: - edges.append( - { - "from": record["from_id"].value, - "to": record["to_id"].value, - "type": record["edge_type"].value, - } - ) - return edges - - @timed - def get_neighbors_by_tag( - self, - tags: list[str], - exclude_ids: list[str], - top_k: int = 5, - min_overlap: int = 1, - include_embedding: bool = False, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Find top-K neighbor nodes with maximum tag overlap. - - Args: - tags: The list of tags to match. - exclude_ids: Node IDs to exclude (e.g., local cluster). - top_k: Max number of neighbors to return. - min_overlap: Minimum number of overlapping tags required. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of dicts with node details and overlap count. 
- """ - if not tags: - return [] - user_name = user_name if user_name else self.config.user_name - where_clauses = [ - 'n.status = "activated"', - 'NOT (n.node_type = "reasoning")', - 'NOT (n.memory_type = "WorkingMemory")', - ] - if exclude_ids: - where_clauses.append(f"NOT (n.id IN {exclude_ids})") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_clause = " AND ".join(where_clauses) - tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" - - return_fields = self._build_return_fields(include_embedding) - query = f""" - LET tag_list = {tag_list_literal} - - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {where_clause} - RETURN {return_fields}, - size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count - ORDER BY overlap_count DESC - LIMIT {top_k} - """ - - result = self.execute_query(query) - neighbors: list[dict[str, Any]] = [] - for r in result: - props = {k: v.value for k, v in r.items() if k != "overlap_count"} - parsed = self._parse_node(props) - parsed["overlap_count"] = r["overlap_count"].value - neighbors.append(parsed) - - neighbors.sort(key=lambda x: x["overlap_count"], reverse=True) - neighbors = neighbors[:top_k] - result = [] - for neighbor in neighbors[:top_k]: - neighbor.pop("overlap_count") - result.append(neighbor) - return result - - @timed - def get_children_with_embeddings( - self, id: str, user_name: str | None = None - ) -> list[dict[str, Any]]: - user_name = user_name if user_name else self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" - - query = f""" - MATCH (p@Memory)-[@PARENT]->(c@Memory) - WHERE p.id = "{id}" {where_user} - RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory - """ - result = self.execute_query(query) - children = [] - for row in result: - eid = row["id"].value # STRING - emb_v = row[self.dim_field].value # NVector - emb = list(emb_v.values) if emb_v else [] - mem = row["memory"].value 
# STRING - - children.append({"id": eid, "embedding": emb, "memory": mem}) - return children - - @timed - def get_subgraph( - self, - center_id: str, - depth: int = 2, - center_status: str = "activated", - user_name: str | None = None, - ) -> dict[str, Any]: - """ - Retrieve a local subgraph centered at a given node. - Args: - center_id: The ID of the center node. - depth: The hop distance for neighbors. - center_status: Required status for center node. - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - { - "core_node": {...}, - "neighbors": [...], - "edges": [...] - } - """ - if not 1 <= depth <= 5: - raise ValueError("depth must be 1-5") - - user_name = user_name if user_name else self.config.user_name - - gql = f""" - MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory) - WHERE neighbor.user_name = '{user_name}' - RETURN center, - collect(DISTINCT neighbor) AS neighbors, - collect(EDGES(p)) AS edge_chains - """ - - result = self.execute_query(gql).one_or_none() - if not result or result.size == 0: - return {"core_node": None, "neighbors": [], "edges": []} - - core_node_props = result["center"].as_node().get_properties() - core_node = self._parse_node(core_node_props) - neighbors = [] - vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]} - for n in result["neighbors"].value: - n_node = n.as_node() - n_props = n_node.get_properties() - node_parsed = self._parse_node(n_props) - neighbors.append(node_parsed) - vid_to_id_map[n_node.node_id] = node_parsed["id"] - - edges = [] - for chain_group in result["edge_chains"].value: - for edge_wr in chain_group.value: - edge = edge_wr.value - edges.append( - { - "type": edge.get_type(), - "source": vid_to_id_map.get(edge.get_src_id()), - "target": vid_to_id_map.get(edge.get_dst_id()), - } - ) 
- - return {"core_node": core_node, "neighbors": neighbors, "edges": edges} - - @timed - # Search / recall operations - def search_by_embedding( - self, - vector: list[float], - top_k: int = 5, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - **kwargs, - ) -> list[dict]: - """ - Retrieve node IDs based on vector similarity. - - Args: - vector (list[float]): The embedding vector representing query semantics. - top_k (int): Number of top similar nodes to retrieve. - scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory'). - status (str, optional): Node status filter (e.g., 'active', 'archived'). - If provided, restricts results to nodes with matching status. - threshold (float, optional): Minimum similarity score threshold (0 ~ 1). - search_filter (dict, optional): Additional metadata filters for search results. - Keys should match node properties, values are the expected values. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. - - Notes: - - This method uses Neo4j native vector indexing to search for similar nodes. - - If scope is provided, it restricts results to nodes with matching memory_type. - - If 'status' is provided, only nodes with the matching status will be returned. - - If threshold is provided, only results with score >= threshold will be returned. - - If search_filter is provided, additional WHERE clauses will be added for metadata filtering. - - Typical use case: restrict to 'status = activated' to avoid - matching archived or merged nodes. 
- """ - user_name = user_name if user_name else self.config.user_name - vector = _normalize(vector) - dim = len(vector) - vector_str = ",".join(f"{float(x)}" for x in vector) - gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - where_clauses = [f"n.{self.dim_field} IS NOT NULL"] - if scope: - where_clauses.append(f'n.memory_type = "{scope}"') - if status: - where_clauses.append(f'n.status = "{status}"') - where_clauses.append(f'n.user_name = "{user_name}"') - - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") - - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - gql = f""" - let a = {gql_vector} - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, a) DESC - LIMIT {top_k} - RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" - try: - result = self.execute_query(gql) - except Exception as e: - logger.error(f"[search_by_embedding] Query failed: {e}") - return [] - - try: - output = [] - for row in result: - values = row.values() - id_val = values[0].as_string() - score_val = values[1].as_double() - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output - except Exception as e: - logger.error(f"[search_by_embedding] Result parse failed: {e}") - return [] - - @timed - def get_by_metadata( - self, filters: list[dict[str, Any]], user_name: str | None = None - ) -> list[str]: - """ - 1. ADD logic: "AND" vs "OR"(support logic combination); - 2. Support nested conditional expressions; - - Retrieve node IDs that match given metadata filters. - Supports exact match. 
- - Args: - filters: List of filter dicts like: - [ - {"field": "key", "op": "in", "value": ["A", "B"]}, - {"field": "confidence", "op": ">=", "value": 80}, - {"field": "tags", "op": "contains", "value": "AI"}, - ... - ] - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[str]: Node IDs whose metadata match the filter conditions. (AND logic). - - Notes: - - Supports structured querying such as tag/category/importance/time filtering. - - Can be used for faceted recall or prefiltering before embedding rerank. - """ - where_clauses = [] - user_name = user_name if user_name else self.config.user_name - for _i, f in enumerate(filters): - field = f["field"] - op = f.get("op", "=") - value = f["value"] - - escaped_value = self._format_value(value) - - # Build WHERE clause - if op == "=": - where_clauses.append(f"n.{field} = {escaped_value}") - elif op == "in": - where_clauses.append(f"n.{field} IN {escaped_value}") - elif op == "contains": - where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") - elif op == "starts_with": - where_clauses.append(f"n.{field} STARTS WITH {escaped_value}") - elif op == "ends_with": - where_clauses.append(f"n.{field} ENDS WITH {escaped_value}") - elif op in [">", ">=", "<", "<="]: - where_clauses.append(f"n.{field} {op} {escaped_value}") - else: - raise ValueError(f"Unsupported operator: {op}") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" - ids = [] - try: - result = self.execute_query(gql) - ids = [record["id"].value for record in result] - except Exception as e: - logger.error(f"Failed to get metadata: {e}, gql is {gql}") - return ids - - @timed - def get_grouped_counts( - self, - group_fields: list[str], - where_clause: str = "", - params: dict[str, Any] | None = None, - user_name: str | None = None, - ) -> list[dict[str, 
Any]]: - """ - Count nodes grouped by any fields. - - Args: - group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] - where_clause (str, optional): Extra WHERE condition. E.g., - "WHERE n.status = 'activated'" - params (dict, optional): Parameters for WHERE clause. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] - """ - if not group_fields: - raise ValueError("group_fields cannot be empty") - user_name = user_name if user_name else self.config.user_name - # GQL-specific modifications - user_clause = f"n.user_name = '{user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" - else: - where_clause = f"WHERE {user_clause}" - - # Inline parameters if provided - if params: - for key, value in params.items(): - # Handle different value types appropriately - if isinstance(value, str): - value = f"'{value}'" - where_clause = where_clause.replace(f"${key}", str(value)) - - return_fields = [] - group_by_fields = [] - - for field in group_fields: - alias = field.replace(".", "_") - return_fields.append(f"n.{field} AS {alias}") - group_by_fields.append(alias) - # Full GQL query construction - gql = f""" - MATCH (n /*+ INDEX(idx_memory_user_name) */) - {where_clause} - RETURN {", ".join(return_fields)}, COUNT(n) AS count - """ - result = self.execute_query(gql) # Pure GQL string execution - - output = [] - for record in result: - group_values = {} - for i, field in enumerate(group_fields): - value = record.values()[i].as_string() - group_values[field] = value - count_value = record["count"].value - output.append({**group_values, "count": count_value}) - - return output - - @timed - def clear(self, user_name: str | None = None) -> None: - """ - Clear the 
entire graph if the target database exists. - - Args: - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - try: - query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" - self.execute_query(query) - logger.info("Cleared all nodes from database.") - - except Exception as e: - logger.error(f"[ERROR] Failed to clear database: {e}") - - @timed - def export_graph( - self, include_embedding: bool = False, user_name: str | None = None, **kwargs - ) -> dict[str, Any]: - """ - Export all graph nodes and edges in a structured form. - Args: - include_embedding (bool): Whether to include the large embedding field. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - { - "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] - } - """ - user_name = user_name if user_name else self.config.user_name - node_query = "MATCH (n@Memory)" - edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - node_query += f' WHERE n.user_name = "{user_name}"' - edge_query += f' WHERE r.user_name = "{user_name}"' - - try: - if include_embedding: - return_fields = "n" - else: - return_fields = ",".join( - [ - "n.id AS id", - "n.memory AS memory", - "n.user_name AS user_name", - "n.user_id AS user_id", - "n.session_id AS session_id", - "n.status AS status", - "n.key AS key", - "n.confidence AS confidence", - "n.tags AS tags", - "n.created_at AS created_at", - "n.updated_at AS updated_at", - "n.memory_type AS memory_type", - "n.sources AS sources", - "n.source AS source", - "n.node_type AS node_type", - "n.visibility AS visibility", - "n.usage AS usage", - "n.background AS background", - ] - ) - - full_node_query = f"{node_query} RETURN {return_fields}" - node_result = self.execute_query(full_node_query, timeout=20) - nodes = [] - logger.debug(f"Debugging: {node_result}") - for row 
in node_result: - if include_embedding: - props = row.values()[0].as_node().get_properties() - else: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - nodes.append(node) - except Exception as e: - raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - - try: - full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge" - edge_result = self.execute_query(full_edge_query, timeout=20) - edges = [ - { - "source": row.values()[0].value, - "target": row.values()[1].value, - "type": row.values()[2].value, - } - for row in edge_result - ] - except Exception as e: - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - - return {"nodes": nodes, "edges": edges} - - @timed - def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: - """ - Import the entire graph from a serialized dictionary. - - Args: - data: A dictionary containing all nodes and edges to be loaded. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - for node in data.get("nodes", []): - try: - id, memory, metadata = _compose_node(node) - metadata["user_name"] = user_name - metadata = self._prepare_node_metadata(metadata) - metadata.update({"id": id, "memory": memory}) - properties = ", ".join( - f"{k}: {self._format_value(v, k)}" for k, v in metadata.items() - ) - node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})" - self.execute_query(node_gql) - except Exception as e: - logger.error(f"Fail to load node: {node}, error: {e}") - - for edge in data.get("edges", []): - try: - source_id, target_id = edge["source"], edge["target"] - edge_type = edge["type"] - props = f'{{user_name: "{user_name}"}}' - edge_gql = f''' - MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) - INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) - ''' - self.execute_query(edge_gql) - except Exception 
as e: - logger.error(f"Fail to load edge: {edge}, error: {e}") - - @timed - def get_all_memory_items( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> (list)[dict]: - """ - Retrieve all memory items of a specific memory_type. - - Args: - scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: Full list of memory items under this scope. - """ - user_name = user_name if user_name else self.config.user_name - if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: - raise ValueError(f"Unsupported memory type scope: {scope}") - - where_clause = f"WHERE n.memory_type = '{scope}'" - where_clause += f" AND n.user_name = '{user_name}'" - - return_fields = self._build_return_fields(include_embedding) - - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - {where_clause} - RETURN {return_fields} - LIMIT 100 - """ - nodes = [] - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - nodes.append(self._parse_node(props)) - except Exception as e: - logger.error(f"Failed to get memories: {e}") - return nodes - - @timed - def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> list[dict]: - """ - Find nodes that are likely candidates for structure optimization: - - Isolated nodes, nodes with empty background, or nodes with exactly one child. - - Plus: the child of any parent node that has exactly one child. 
- """ - user_name = user_name if user_name else self.config.user_name - where_clause = f''' - n.memory_type = "{scope}" - AND n.status = "activated" - ''' - where_clause += f' AND n.user_name = "{user_name}"' - - return_fields = self._build_return_fields(include_embedding) - return_fields += f", n.{self.dim_field} AS {self.dim_field}" - - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {where_clause} - OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) - OPTIONAL MATCH (p@Memory)-[@PARENT]->(n) - WHERE c IS NULL AND p IS NULL - RETURN {return_fields} - """ - - candidates = [] - node_ids = set() - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - except Exception as e: - logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") - return candidates - - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. - """ - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) - - @timed - def detect_conflicts(self) -> list[tuple[str, str]]: - """ - Detect conflicting nodes based on logical or semantic inconsistency. - Returns: - A list of (node_id1, node_id2) tuples that conflict. - """ - raise NotImplementedError - - @timed - # Structure Maintenance - def deduplicate_nodes(self) -> None: - """ - Deduplicate redundant or semantically similar nodes. - This typically involves identifying nodes with identical or near-identical memory. - """ - raise NotImplementedError - - @timed - def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: - """ - Get the ordered context chain starting from a node, following a relationship type. 
- Args: - id: Starting node ID. - type: Relationship type to follow (e.g., 'FOLLOWS'). - Returns: - List of ordered node IDs in the chain. - """ - raise NotImplementedError - - @timed - def get_neighbors( - self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" - ) -> list[str]: - """ - Get connected node IDs in a specific direction and relationship type. - Args: - id: Source node ID. - type: Relationship type. - direction: Edge direction to follow ('out', 'in', or 'both'). - Returns: - List of neighboring node IDs. - """ - raise NotImplementedError - - @timed - def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: - """ - Get the path of nodes from source to target within a limited depth. - Args: - source_id: Starting node ID. - target_id: Target node ID. - max_depth: Maximum path length to traverse. - Returns: - Ordered list of node IDs along the path. - """ - raise NotImplementedError - - @timed - def merge_nodes(self, id1: str, id2: str) -> str: - """ - Merge two similar or duplicate nodes into one. - Args: - id1: First node ID. - id2: Second node ID. - Returns: - ID of the resulting merged node. 
- """ - raise NotImplementedError - - @classmethod - def _ensure_space_exists(cls, tmp_client, cfg): - """Lightweight check to ensure target graph (space) exists.""" - db_name = getattr(cfg, "space", None) - if not db_name: - logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.") - return - - try: - res = tmp_client.execute("SHOW GRAPHS") - existing = {row.values()[0].as_string() for row in res} - if db_name not in existing: - tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type") - logger.info(f"โœ… Graph `{db_name}` created before session binding.") - else: - logger.debug(f"Graph `{db_name}` already exists.") - except Exception: - logger.exception("[NebulaGraphDBSync] Failed to ensure space exists") - - @timed - def _ensure_database_exists(self): - graph_type_name = "MemOSBgeM3Type" - - check_type_query = "SHOW GRAPH TYPES" - result = self.execute_query(check_type_query, auto_set_db=False) - - type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result) - - if not type_exists: - create_tag = f""" - CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{ - NODE Memory (:MemoryTag {{ - id STRING, - memory STRING, - user_name STRING, - user_id STRING, - session_id STRING, - status STRING, - key STRING, - confidence FLOAT, - tags LIST, - created_at STRING, - updated_at STRING, - memory_type STRING, - sources LIST, - source STRING, - node_type STRING, - visibility STRING, - usage LIST, - background STRING, - {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>, - PRIMARY KEY(id) - }}), - EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory) - }} - """ - self.execute_query(create_tag, 
auto_set_db=False) - else: - describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}" - desc_result = self.execute_query(describe_query, auto_set_db=False) - - memory_fields = [] - for row in desc_result: - field_name = row.values()[0].as_string() - memory_fields.append(field_name) - - if self.dim_field not in memory_fields: - alter_query = f""" - ALTER GRAPH TYPE {graph_type_name} {{ - ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }} - }} - """ - self.execute_query(alter_query, auto_set_db=False) - logger.info(f"โœ… Add new vector search {self.dim_field} to {graph_type_name}") - else: - logger.info(f"โœ… Graph Type {graph_type_name} already include {self.dim_field}") - - create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}" - try: - self.execute_query(create_graph, auto_set_db=False) - logger.info(f"โœ… Graph ``{self.db_name}`` is now the working graph.") - except Exception as e: - logger.error(f"โŒ Failed to create tag: {e} trace: {traceback.format_exc()}") - - @timed - def _create_vector_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 3072, - index_name: str = "memory_vector_index", - ) -> None: - """ - Create a vector index for the specified property in the label. 
- """ - if str(dimensions) == str(self.default_memory_dimension): - index_name = f"idx_{vector_property}" - vector_name = vector_property - else: - index_name = f"idx_{vector_property}_{dimensions}" - vector_name = f"{vector_property}_{dimensions}" - - create_vector_index = f""" - CREATE VECTOR INDEX IF NOT EXISTS {index_name} - ON NODE {label}::{vector_name} - OPTIONS {{ - DIM: {dimensions}, - METRIC: IP, - TYPE: IVF, - NLIST: 100, - TRAINSIZE: 1000 - }} - FOR `{self.db_name}` - """ - self.execute_query(create_vector_index) - logger.info( - f"โœ… Ensure {label}::{vector_property} vector index {index_name} " - f"exists (DIM={dimensions})" - ) - - @timed - def _create_basic_property_indexes(self) -> None: - """ - Create standard B-tree indexes on status, memory_type, created_at - and updated_at fields. - Create standard B-tree indexes on user_name when use Shared Database - Multi-Tenant Mode. - """ - fields = [ - "status", - "memory_type", - "created_at", - "updated_at", - "user_name", - ] - - for field in fields: - index_name = f"idx_memory_{field}" - gql = f""" - CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field}) - FOR `{self.db_name}` - """ - try: - self.execute_query(gql) - logger.info(f"โœ… Created index: {index_name} on field {field}") - except Exception as e: - logger.error( - f"โŒ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}" - ) - - @timed - def _index_exists(self, index_name: str) -> bool: - """ - Check if an index with the given name exists. - """ - """ - Check if a vector index with the given name exists in NebulaGraph. - - Args: - index_name (str): The name of the index to check. - - Returns: - bool: True if the index exists, False otherwise. 
- """ - query = "SHOW VECTOR INDEXES" - try: - result = self.execute_query(query) - return any(row.values()[0].as_string() == index_name for row in result) - except Exception as e: - logger.error(f"[Nebula] Failed to check index existence: {e}") - return False - - @timed - def _parse_value(self, value: Any) -> Any: - """turn Nebula ValueWrapper to Python type""" - from nebulagraph_python.value_wrapper import ValueWrapper - - if value is None or (hasattr(value, "is_null") and value.is_null()): - return None - try: - prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value - except Exception as e: - logger.warning(f"Error when decode Nebula ValueWrapper: {e}") - prim = value.cast() if isinstance(value, ValueWrapper) else value - - if isinstance(prim, ValueWrapper): - return self._parse_value(prim) - if isinstance(prim, list): - return [self._parse_value(v) for v in prim] - if type(prim).__name__ == "NVector": - return list(prim.values) - - return prim # already a Python primitive - - def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]: - parsed = {k: self._parse_value(v) for k, v in props.items()} - - for tf in ("created_at", "updated_at"): - if tf in parsed and parsed[tf] is not None: - parsed[tf] = _normalize_datetime(parsed[tf]) - - node_id = parsed.pop("id") - memory = parsed.pop("memory", "") - parsed.pop("user_name", None) - metadata = parsed - metadata["type"] = metadata.pop("node_type") - - if self.dim_field in metadata: - metadata["embedding"] = metadata.pop(self.dim_field) - - return {"id": node_id, "memory": memory, "metadata": metadata} - - @timed - def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """ - Ensure metadata has proper datetime fields and normalized types. - - - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). - - Convert embedding to list of float if present. 
- """ - now = datetime.utcnow().isoformat() - metadata["node_type"] = metadata.pop("type") - - # Fill timestamps if missing - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - - # Normalize embedding type - embedding = metadata.get("embedding") - if embedding and isinstance(embedding, list): - metadata.pop("embedding") - metadata[self.dim_field] = _normalize([float(x) for x in embedding]) - - return metadata - - @timed - def _format_value(self, val: Any, key: str = "") -> str: - from nebulagraph_python.py_data_types import NVector - - # None - if val is None: - return "NULL" - # bool - if isinstance(val, bool): - return "true" if val else "false" - # str - if isinstance(val, str): - return f'"{_escape_str(val)}"' - # num - elif isinstance(val, (int | float)): - return str(val) - # time - elif isinstance(val, datetime): - return f'datetime("{val.isoformat()}")' - # list - elif isinstance(val, list): - if key == self.dim_field: - dim = len(val) - joined = ",".join(str(float(x)) for x in val) - return f"VECTOR<{dim}, FLOAT>([{joined}])" - else: - return f"[{', '.join(self._format_value(v) for v in val)}]" - # NVector - elif isinstance(val, NVector): - if key == self.dim_field: - dim = len(val) - joined = ",".join(str(float(x)) for x in val) - return f"VECTOR<{dim}, FLOAT>([{joined}])" - else: - logger.warning("Invalid NVector") - # dict - if isinstance(val, dict): - j = json.dumps(val, ensure_ascii=False, separators=(",", ":")) - return f'"{_escape_str(j)}"' - else: - return f'"{_escape_str(str(val))}"' - - @timed - def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]: - """ - Filter and validate metadata dictionary against the Memory node schema. - - Removes keys not in schema. - - Warns if required fields are missing. 
- """ - - dim_fields = {self.dim_field} - - allowed_fields = self.common_fields | dim_fields - - missing_fields = allowed_fields - metadata.keys() - if missing_fields: - logger.info(f"Metadata missing required fields: {sorted(missing_fields)}") - - filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields} - - return filtered_metadata - - def _build_return_fields(self, include_embedding: bool = False) -> str: - fields = set(self.base_fields) - if include_embedding: - fields.add(self.dim_field) - return ", ".join(f"n.{f} AS {f}" for f in fields) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 33eb39692..56c3e08a0 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -39,7 +39,7 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: metadata["embedding"] = [float(x) for x in embedding] # serialization - if metadata["sources"]: + if metadata.get("sources"): for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) return metadata @@ -72,8 +72,35 @@ def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]: return metadata +def _sanitize_neo4j_value(value: Any) -> Any: + """Convert values unsupported by Neo4j properties into safe serializations.""" + if value is None or isinstance(value, str | int | float | bool): + return value + + if isinstance(value, list): + if all(item is None or isinstance(item, str | int | float | bool) for item in value): + return value + return [ + json.dumps(item, ensure_ascii=False) if isinstance(item, dict | list) else str(item) + for item in value + ] + + if isinstance(value, dict): + return json.dumps(value, ensure_ascii=False, sort_keys=True) + + return str(value) + + +def _sanitize_neo4j_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """Ensure all metadata values are valid Neo4j property types.""" + return {key: _sanitize_neo4j_value(value) for key, value in 
metadata.items()} + + class Neo4jGraphDB(BaseGraphDB): - """Neo4j-based implementation of a graph memory store.""" + """Neo4j-based implementation of a graph memory store. + + Requires Neo4j >= 5.18 for vector.similarity.cosine() pre-filtering support. + """ @require_python_package( import_name="neo4j", @@ -209,6 +236,9 @@ def add_node( # Flatten info fields to top level (for Neo4j flat structure) metadata = _flatten_info_fields(metadata) + # Ensure Neo4j property compatibility (no nested map/list-of-map values) + metadata = _sanitize_neo4j_metadata(metadata) + # Initialize delete_time and delete_record_id fields metadata.setdefault("delete_time", "") metadata.setdefault("delete_record_id", "") @@ -226,7 +256,7 @@ def add_node( """ # serialization - if metadata["sources"]: + if metadata.get("sources"): for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) @@ -843,13 +873,14 @@ def search_by_embedding( If return_fields is specified, each dict also includes the requested fields. Notes: - - This method uses Neo4j native vector indexing to search for similar nodes. - - If scope is provided, it restricts results to nodes with matching memory_type. - - If 'status' is provided, only nodes with the matching status will be returned. + - When filters are present (scope, status, user_name, etc.), this method uses + Neo4j 5.18+ pre-filtering: MATCH + WHERE narrows candidates first, then + vector.similarity.cosine() computes similarity only on the filtered set. + This avoids the post-filter problem where queryNodes' global top-k excludes + the target user's nodes in a multi-tenant shared database. + - When no filters are present, the ANN vector index (db.index.vector.queryNodes) + is used for maximum efficiency. - If threshold is provided, only results with score >= threshold will be returned. - - If search_filter is provided, additional WHERE clauses will be added for metadata filtering. 
- - Typical use case: restrict to 'status = activated' to avoid - matching archived or merged nodes. """ user_name = user_name if user_name else self.config.user_name # Build WHERE clause dynamically @@ -901,14 +932,28 @@ def search_by_embedding( if extra_fields: return_clause = f"RETURN node.id AS id, score, {extra_fields}" - query = f""" - CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) - YIELD node, score - {where_clause} - {return_clause} - """ - - parameters = {"embedding": vector, "k": top_k} + if where_clause: + # Pre-filtering (Neo4j 5.18+): filter nodes first, then compute similarity. + # This avoids the post-filter problem where relevant nodes are excluded + # from the global top-k returned by queryNodes. + where_clause += " AND node.embedding IS NOT NULL" + query = f""" + MATCH (node:Memory) + {where_clause} + WITH node, vector.similarity.cosine(node.embedding, $embedding) AS score + {return_clause} + ORDER BY score DESC + LIMIT $top_k + """ + parameters = {"embedding": vector, "top_k": top_k} + else: + # No filter: use ANN vector index for efficiency. 
+ query = f""" + CALL db.index.vector.queryNodes('memory_vector_index', $top_k, $embedding) + YIELD node, score + {return_clause} + """ + parameters = {"embedding": vector, "top_k": top_k} if scope: parameters["scope"] = scope @@ -1842,7 +1887,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: if not ( isinstance(node["sources"][idx], str) and node["sources"][idx][0] == "{" - and node["sources"][idx][0] == "}" + and node["sources"][idx][-1] == "}" ): break node["sources"][idx] = json.loads(node["sources"][idx]) diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 283e15115..b5c92f40a 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -5,7 +5,12 @@ from typing import Any from memos.configs.graph_db import Neo4jGraphDBConfig -from memos.graph_dbs.neo4j import Neo4jGraphDB, _flatten_info_fields, _prepare_node_metadata +from memos.graph_dbs.neo4j import ( + Neo4jGraphDB, + _flatten_info_fields, + _prepare_node_metadata, + _sanitize_neo4j_metadata, +) from memos.log import get_logger from memos.vec_dbs.factory import VecDBFactory from memos.vec_dbs.item import VecDBItem @@ -55,40 +60,43 @@ def add_node( # Safely process metadata metadata = _prepare_node_metadata(metadata) + metadata = _flatten_info_fields(metadata) + metadata = _sanitize_neo4j_metadata(metadata) # Initialize delete_time and delete_record_id fields metadata.setdefault("delete_time", "") metadata.setdefault("delete_record_id", "") # serialization - if metadata["sources"]: + if metadata.get("sources"): for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) # Extract required fields embedding = metadata.pop("embedding", None) - if embedding is None: - raise ValueError(f"Missing 'embedding' in metadata for node {id}") # Merge node and set metadata created_at = metadata.pop("created_at") updated_at = metadata.pop("updated_at") - 
vector_sync_status = "success" + vector_sync_status = "skipped" - try: - # Write to Vector DB - item = VecDBItem( - id=id, - vector=embedding, - payload={ - "memory": memory, - "vector_sync": vector_sync_status, - **metadata, # unpack all metadata keys to top-level - }, - ) - self.vec_db.add([item]) - except Exception as e: - logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}") - vector_sync_status = "failed" + if embedding is not None: + vector_sync_status = "success" + try: + item = VecDBItem( + id=id, + vector=embedding, + payload={ + "memory": memory, + "vector_sync": vector_sync_status, + **metadata, + }, + ) + self.vec_db.add([item]) + except Exception as e: + logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}") + vector_sync_status = "failed" + else: + logger.warning(f"[add_node] No embedding for node {id}, skipping vector DB insert") metadata["vector_sync"] = vector_sync_status query = """ @@ -134,6 +142,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N metadata = _prepare_node_metadata(metadata) metadata = _flatten_info_fields(metadata) + metadata = _sanitize_neo4j_metadata(metadata) # Initialize delete_time and delete_record_id fields metadata.setdefault("delete_time", "") @@ -141,18 +150,21 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N embedding = metadata.pop("embedding", None) - vector_sync_status = "success" - vec_items.append( - VecDBItem( - id=node_id, - vector=embedding, - payload={ - "memory": memory, - "vector_sync": vector_sync_status, - **metadata, - }, + if embedding is not None: + vector_sync_status = "success" + vec_items.append( + VecDBItem( + id=node_id, + vector=embedding, + payload={ + "memory": memory, + "vector_sync": vector_sync_status, + **metadata, + }, + ) ) - ) + else: + vector_sync_status = "skipped" created_at = metadata.pop("created_at") updated_at = metadata.pop("updated_at") @@ -889,6 +901,14 @@ def build_filter_condition( if 
condition_str: where_clauses.append(f"({condition_str})") filter_params.update(filter_params_inner) + else: + # Simple dict syntax: {"user_id": "...", "session_id": "..."} + condition_str, filter_params_inner = build_filter_condition( + filter, param_counter + ) + if condition_str: + where_clauses.append(f"({condition_str})") + filter_params.update(filter_params_inner) where_str = " AND ".join(where_clauses) if where_clauses else "" if where_str: @@ -919,7 +939,7 @@ def build_filter_condition( def delete_node_by_prams( self, - writable_cube_ids: list[str], + writable_cube_ids: list[str] | None = None, memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, @@ -928,7 +948,7 @@ def delete_node_by_prams( Delete nodes by memory_ids, file_ids, or filter. Args: - writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter. + writable_cube_ids (list[str], optional): List of cube IDs (user_name) to scope deletion. memory_ids (list[str], optional): List of memory node IDs to delete. file_ids (list[str], optional): List of file node IDs to delete. filter (dict, optional): Filter dictionary to query matching nodes for deletion. @@ -943,9 +963,9 @@ def delete_node_by_prams( f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" ) - # Validate writable_cube_ids - if not writable_cube_ids or len(writable_cube_ids) == 0: - raise ValueError("writable_cube_ids is required and cannot be empty") + # file_ids deletion must be scoped by writable_cube_ids. 
+ if file_ids and (not writable_cube_ids or len(writable_cube_ids) == 0): + raise ValueError("writable_cube_ids is required when deleting by file_ids") # Build WHERE conditions separately for memory_ids and file_ids where_clauses = [] @@ -953,10 +973,11 @@ def delete_node_by_prams( # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) user_name_conditions = [] - for idx, cube_id in enumerate(writable_cube_ids): - param_name = f"cube_id_{idx}" - user_name_conditions.append(f"n.user_name = ${param_name}") - params[param_name] = cube_id + if writable_cube_ids: + for idx, cube_id in enumerate(writable_cube_ids): + param_name = f"cube_id_{idx}" + user_name_conditions.append(f"n.user_name = ${param_name}") + params[param_name] = cube_id # Handle memory_ids: query n.id if memory_ids and len(memory_ids) > 0: @@ -1003,9 +1024,12 @@ def delete_node_by_prams( # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) data_conditions = " OR ".join([f"({clause})" for clause in where_clauses]) - # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions) - user_name_where = " OR ".join(user_name_conditions) - ids_where = f"({user_name_where}) AND ({data_conditions})" + # Then, combine with user_name condition using AND when scope is provided. 
+ if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + ids_where = f"({user_name_where}) AND ({data_conditions})" + else: + ids_where = data_conditions logger.info( f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" @@ -1138,12 +1162,12 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: node[time_field] = node[time_field].isoformat() node.pop("user_name", None) # serialization - if node["sources"]: + if node.get("sources"): for idx in range(len(node["sources"])): if not ( isinstance(node["sources"][idx], str) and node["sources"][idx][0] == "{" - and node["sources"][idx][0] == "}" + and node["sources"][idx][-1] == "}" ): break node["sources"][idx] = json.loads(node["sources"][idx]) @@ -1179,7 +1203,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]] if not ( isinstance(node["sources"][idx], str) and node["sources"][idx][0] == "{" - and node["sources"][idx][0] == "}" + and node["sources"][idx][-1] == "}" ): break node["sources"][idx] = json.loads(node["sources"][idx]) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 4d88844df..abfae7710 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -165,7 +165,7 @@ def __init__(self, config: PolarDBGraphDBConfig): # Create connection pool self.connection_pool = psycopg2.pool.ThreadedConnectionPool( - minconn=5, + minconn=1, maxconn=maxconn, host=host, port=port, @@ -176,6 +176,7 @@ def __init__(self, config: PolarDBGraphDBConfig): keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout) keepalives_interval=15, # Seconds between keepalive retries keepalives_count=5, # Number of keepalive retries before considering connection dead + keepalives=1, options=f"-c search_path={self.db_name}_graph,ag_catalog,$user,public", ) @@ -250,39 +251,40 @@ def _get_connection(self): import psycopg2 
timeout = self._connection_wait_timeout - if not self._semaphore.acquire(timeout=max(timeout, 0)): + if timeout is None or timeout <= 0: + self._semaphore.acquire() + elif not self._semaphore.acquire(timeout=timeout): logger.warning(f"Timeout waiting for connection slot ({timeout}s)") raise RuntimeError("Connection pool busy") - logger.info( - "Connection pool usage: %s/%s", - self.connection_pool.maxconn - self._semaphore._value, - self.connection_pool.maxconn, - ) + conn = None broken = False try: - conn = self.connection_pool.getconn() - conn.autocommit = True for attempt in range(2): + conn = self.connection_pool.getconn() + conn.autocommit = True try: with conn.cursor() as cur: cur.execute("SELECT 1") break except psycopg2.Error: - logger.warning("Dead connection detected, recreating (attempt %d)", attempt + 1) + logger.warning(f"Dead connection detected, recreating (attempt {attempt + 1})") self.connection_pool.putconn(conn, close=True) - conn = self.connection_pool.getconn() - conn.autocommit = True + conn = None else: raise RuntimeError("Cannot obtain valid DB connection after 2 attempts") with conn.cursor() as cur: cur.execute(f'SET search_path = {self.db_name}_graph, ag_catalog, "$user", public;') yield conn - except Exception: + except (psycopg2.Error, psycopg2.OperationalError) as e: broken = True + logger.exception(f"Database connection busy : {e}") + raise + except Exception as e: + logger.exception(f"Unexpected error: {e}") raise finally: - if conn: + if conn is not None: try: self.connection_pool.putconn(conn, close=broken) logger.debug(f"Returned connection {id(conn)} to pool (broken={broken})") @@ -441,7 +443,7 @@ def remove_oldest_memory( ) user_name = user_name if user_name else self._get_config_value("user_name") - # Use actual OFFSET logic, consistent with nebular.py + # Use actual OFFSET logic for deterministic pruning # First find IDs to delete, then delete them select_query = f""" SELECT id FROM "{self.db_name}_graph"."Memory" @@ -1798,7 
+1800,6 @@ def search_by_fulltext( FROM "{self.db_name}_graph"."Memory" m CROSS JOIN q {where_clause_cte} - ORDER BY rank DESC LIMIT {top_k}; """ params = [tsquery_string] @@ -2194,7 +2195,7 @@ def get_grouped_counts( SELECT {", ".join(cte_select_list)} FROM "{self.db_name}_graph"."Memory" {where_clause} - LIMIT 1000 + LIMIT 100 ) SELECT {outer_select}, count(*) AS count FROM t @@ -2409,56 +2410,47 @@ def _extract_special_filter_values(filter_obj): order_clause = """ ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,id DESC """ + count_query = f""" + SELECT COUNT(*) AS total_count + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ if include_embedding: - node_query = f""" - WITH filtered AS ( - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ) - SELECT p.id, p.properties, p.embedding, c.total_count - FROM (SELECT COUNT(*) AS total_count FROM filtered) c - LEFT JOIN LATERAL ( - SELECT id, properties, embedding - FROM filtered - {order_clause} - {pagination_clause} - ) p ON TRUE + data_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + {where_clause} + {order_clause} + {pagination_clause} """ else: - node_query = f""" - WITH filtered AS ( - SELECT id, properties - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ) - SELECT p.id, p.properties, c.total_count - FROM (SELECT COUNT(*) AS total_count FROM filtered) c - LEFT JOIN LATERAL ( - SELECT id, properties - FROM filtered - {order_clause} - {pagination_clause} - ) p ON TRUE + data_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + {order_clause} + {pagination_clause} """ - logger.info(f"[export_graph nodes] Query: {node_query}") + logger.info(f"[export_graph nodes] count_query: {count_query}") + logger.info(f"[export_graph nodes] data_query: {data_query}") try: with self._get_connection() as conn, conn.cursor() as cursor: - 
cursor.execute(node_query) + cursor.execute(count_query) + count_row = cursor.fetchone() + total_nodes = int(count_row[0]) if count_row and count_row[0] is not None else 0 + + cursor.execute(data_query) node_results = cursor.fetchall() nodes = [] for row in node_results: if include_embedding: - row_id, properties_json, embedding_json, row_total_count = row + row_id, properties_json, embedding_json = row else: - row_id, properties_json, row_total_count = row + row_id, properties_json = row embedding_json = None - if row_total_count is not None: - total_nodes = int(row_total_count) - if row_id is None: continue @@ -3539,7 +3531,7 @@ def get_neighbors_by_tag_ccl( user_name = user_name if user_name else self._get_config_value("user_name") - # Build query conditions; keep consistent with nebular.py + # Build query conditions shared with other graph backends where_clauses = [ 'n.status = "activated"', 'NOT (n.node_type = "reasoning")', @@ -3592,7 +3584,7 @@ def get_neighbors_by_tag_ccl( # Add overlap_count result_fields.append("overlap_count agtype") result_fields_str = ", ".join(result_fields) - # Use Cypher query; keep consistent with nebular.py + # Use Cypher query to keep the graph query path aligned query = f""" SELECT * FROM ( SELECT * FROM cypher('{self.db_name}_graph', $$ diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py index 1c1cae378..594f7e695 100644 --- a/src/memos/graph_dbs/postgres.py +++ b/src/memos/graph_dbs/postgres.py @@ -10,6 +10,7 @@ """ import json +import re import time from contextlib import suppress @@ -438,6 +439,202 @@ def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]: result["metadata"]["embedding"] = row[5] return result + @staticmethod + def _is_safe_field_name(field: str) -> bool: + """Validate field names used in dynamic SQL fragments.""" + return bool(re.match(r"^[A-Za-z_][A-Za-z0-9_]*$", field)) + + def _field_expr(self, key: str) -> tuple[str, str]: + """ + Build text/json SQL 
expressions for a filter key. + + Returns: + tuple[text_expr, json_expr] + """ + direct_columns = {"id", "memory", "user_name", "created_at", "updated_at"} + if key in direct_columns: + return key, key + + if key.startswith("info."): + sub_key = key[5:] + if not self._is_safe_field_name(sub_key): + raise ValueError(f"Invalid filter field: {key}") + return f"properties->'info'->>'{sub_key}'", f"properties->'info'->'{sub_key}'" + + if not self._is_safe_field_name(key): + raise ValueError(f"Invalid filter field: {key}") + return f"properties->>'{key}'", f"properties->'{key}'" + + def _build_single_filter_condition( + self, condition_dict: dict[str, Any], params: list[Any] + ) -> str | None: + """Build SQL for a single filter condition dict.""" + if not condition_dict: + return None + + array_fields = {"tags", "sources", "file_ids"} + timestamp_fields = {"created_at", "updated_at"} + parts: list[str] = [] + + for key, value in condition_dict.items(): + text_expr, json_expr = self._field_expr(key) + raw_key = key[5:] if key.startswith("info.") else key + + if isinstance(value, dict): + for op, op_value in value.items(): + if op in ("gt", "lt", "gte", "lte"): + op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} + sql_op = op_map[op] + if raw_key in timestamp_fields or raw_key.endswith("_at"): + parts.append(f"({text_expr})::timestamptz {sql_op} %s::timestamptz") + params.append(op_value) + else: + parts.append(f"NULLIF({text_expr}, '')::numeric {sql_op} %s") + params.append(op_value) + elif op == "contains": + if raw_key in array_fields: + parts.append(f"{json_expr} @> %s::jsonb") + params.append(json.dumps([op_value])) + else: + parts.append(f"{text_expr} ILIKE %s") + params.append(f"%{op_value}%") + elif op == "in": + if not isinstance(op_value, list): + raise ValueError( + f"in operator expects list for '{key}', got {type(op_value).__name__}" + ) + if raw_key in array_fields: + parts.append(f"{json_expr} ?| %s") + params.append([str(v) for v in op_value]) + 
else: + parts.append(f"{text_expr} = ANY(%s)") + params.append([str(v) for v in op_value]) + elif op == "like": + parts.append(f"{text_expr} ILIKE %s") + params.append(f"%{op_value}%") + else: + raise ValueError(f"Unsupported filter operator: {op}") + else: + if raw_key in array_fields: + if isinstance(value, list): + parts.append(f"{json_expr} @> %s::jsonb") + params.append(json.dumps(value)) + else: + parts.append(f"{json_expr} @> %s::jsonb") + params.append(json.dumps([value])) + else: + parts.append(f"{text_expr} = %s") + params.append(str(value)) + + if not parts: + return None + return " AND ".join(parts) + + def _build_filter_where_clause(self, filter_dict: dict[str, Any], params: list[Any]) -> str: + """Build SQL WHERE fragment from filter dict.""" + if not filter_dict: + return "" + + if "and" in filter_dict: + and_conditions = filter_dict.get("and") + if not isinstance(and_conditions, list): + raise ValueError("Invalid filter format: 'and' must be a list") + parts: list[str] = [] + for cond in and_conditions: + if isinstance(cond, dict): + cond_sql = self._build_single_filter_condition(cond, params) + if cond_sql: + parts.append(f"({cond_sql})") + return " AND ".join(parts) + + if "or" in filter_dict: + or_conditions = filter_dict.get("or") + if not isinstance(or_conditions, list): + raise ValueError("Invalid filter format: 'or' must be a list") + parts: list[str] = [] + for cond in or_conditions: + if isinstance(cond, dict): + cond_sql = self._build_single_filter_condition(cond, params) + if cond_sql: + parts.append(f"({cond_sql})") + return f"({' OR '.join(parts)})" if parts else "" + + cond_sql = self._build_single_filter_condition(filter_dict, params) + return cond_sql or "" + + def delete_node_by_prams( + self, + writable_cube_ids: list[str] | None = None, + memory_ids: list[str] | None = None, + file_ids: list[str] | None = None, + filter: dict | None = None, + ) -> int: + """Delete nodes by memory_ids, file_ids, or filter.""" + logger.info( + 
"[delete_node_by_prams] memory_ids: %s, file_ids: %s, filter: %s, writable_cube_ids: %s", + memory_ids, + file_ids, + filter, + writable_cube_ids, + ) + + where_conditions: list[str] = [] + params: list[Any] = [] + + if memory_ids: + where_conditions.append("id = ANY(%s)") + params.append(memory_ids) + + if file_ids: + file_conditions: list[str] = [] + for file_id in file_ids: + file_conditions.append("properties->'file_ids' @> %s::jsonb") + params.append(json.dumps([file_id])) + if file_conditions: + where_conditions.append(f"({' OR '.join(file_conditions)})") + + if filter: + filter_where = self._build_filter_where_clause(filter, params) + if filter_where: + where_conditions.append(f"({filter_where})") + + if not where_conditions: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) + return 0 + + if writable_cube_ids: + where_conditions.append("user_name = ANY(%s)") + params.append(writable_cube_ids) + + where_clause = " AND ".join(where_conditions) + + conn = self._get_conn() + try: + with conn.cursor() as cur: + query = f""" + WITH to_delete AS ( + SELECT id + FROM {self.schema}.memories + WHERE {where_clause} + ), + deleted_edges AS ( + DELETE FROM {self.schema}.edges e + USING to_delete d + WHERE e.source_id = d.id OR e.target_id = d.id + ) + DELETE FROM {self.schema}.memories m + USING to_delete d + WHERE m.id = d.id + """ + cur.execute(query, params) + deleted_count = cur.rowcount if cur.rowcount is not None else 0 + logger.info("[delete_node_by_prams] Deleted %s nodes", deleted_count) + return deleted_count + finally: + self._put_conn(conn) + # ========================================================================= # Edge Management # ========================================================================= diff --git a/src/memos/llms/factory.py b/src/memos/llms/factory.py index 8f4da662f..20b27593e 100644 --- a/src/memos/llms/factory.py +++ b/src/memos/llms/factory.py @@ -5,6 +5,7 @@ from 
memos.llms.deepseek import DeepSeekLLM from memos.llms.hf import HFLLM from memos.llms.hf_singleton import HFSingletonLLM +from memos.llms.minimax import MinimaxLLM from memos.llms.ollama import OllamaLLM from memos.llms.openai import AzureLLM, OpenAILLM from memos.llms.openai_new import OpenAIResponsesLLM @@ -25,6 +26,7 @@ class LLMFactory(BaseLLM): "vllm": VLLMLLM, "qwen": QwenLLM, "deepseek": DeepSeekLLM, + "minimax": MinimaxLLM, "openai_new": OpenAIResponsesLLM, } diff --git a/src/memos/llms/minimax.py b/src/memos/llms/minimax.py new file mode 100644 index 000000000..3bee9882f --- /dev/null +++ b/src/memos/llms/minimax.py @@ -0,0 +1,13 @@ +from memos.configs.llm import MinimaxLLMConfig +from memos.llms.openai import OpenAILLM +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class MinimaxLLM(OpenAILLM): + """MiniMax LLM class via OpenAI-compatible API.""" + + def __init__(self, config: MinimaxLLMConfig): + super().__init__(config) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 135058a7d..a57a40676 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -2,6 +2,7 @@ import difflib import json import re +import uuid from datetime import datetime from typing import TYPE_CHECKING, Any, Literal @@ -236,6 +237,7 @@ def _single_add_operation( else: to_add_memory = new_memory_item.model_copy(deep=True) + to_add_memory.id = str(uuid.uuid4()) if to_add_memory.metadata.memory_type == "PreferenceMemory": to_add_memory.metadata.preference = new_memory_item.memory @@ -359,9 +361,14 @@ def semantics_feedback( lang = detect_lang("".join(memory_item.memory)) template = FEEDBACK_PROMPT_DICT["compare"][lang] if current_memories == []: - # retrieve - last_user_index = max(i for i, d in enumerate(chat_history_list) if d["role"] == "user") - last_qa = " ".join([item["content"] for item in chat_history_list[last_user_index:]]) + user_indices = [i for i, d in 
enumerate(chat_history_list) if d["role"] == "user"] + if user_indices: + last_user_index = max(user_indices) + last_qa = " ".join( + [item["content"] for item in chat_history_list[last_user_index:]] + ) + else: + last_qa = " ".join([item["content"] for item in chat_history_list]) supplementary_retrieved = self._retrieve(last_qa, info=info, user_name=user_name) feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name) diff --git a/src/memos/mem_os/client.py b/src/memos/mem_os/client.py deleted file mode 100644 index f4a591e59..000000000 --- a/src/memos/mem_os/client.py +++ /dev/null @@ -1,5 +0,0 @@ -# TODO: @Li Ji - - -class ClientMOS: - pass diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py deleted file mode 100644 index b2c74c384..000000000 --- a/src/memos/mem_os/product.py +++ /dev/null @@ -1,1610 +0,0 @@ -import asyncio -import json -import os -import random -import time - -from collections.abc import Generator -from datetime import datetime -from typing import Any, Literal - -from dotenv import load_dotenv -from transformers import AutoTokenizer - -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig -from memos.context.context import ContextThread -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_os.core import MOSCore -from memos.mem_os.utils.format_utils import ( - clean_json_response, - convert_graph_to_tree_forworkmem, - ensure_unique_tree_ids, - filter_nodes_by_tree_ids, - remove_embedding_recursive, - sort_children_by_memory_type, -) -from memos.mem_os.utils.reference_utils import ( - prepare_reference_data, - process_streaming_references_complete, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import ( - ANSWER_TASK_LABEL, - QUERY_TASK_LABEL, -) -from memos.mem_user.persistent_factory import PersistentUserManagerFactory -from 
memos.mem_user.user_manager import UserRole -from memos.memories.textual.item import ( - TextualMemoryItem, -) -from memos.templates.mos_prompts import ( - FURTHER_SUGGESTION_PROMPT, - SUGGESTION_QUERY_PROMPT_EN, - SUGGESTION_QUERY_PROMPT_ZH, - get_memos_prompt, -) -from memos.types import MessageList -from memos.utils import timed - - -logger = get_logger(__name__) - -load_dotenv() - -CUBE_PATH = os.getenv("MOS_CUBE_PATH", "/tmp/data/") - - -def _short_id(mem_id: str) -> str: - return (mem_id or "").split("-")[0] if mem_id else "" - - -def _format_mem_block(memories_all, max_items: int = 20, max_chars_each: int = 320) -> str: - """ - Modify TextualMemoryItem Format: - 1:abcd :: [P] text... - 2:ef01 :: [O] text... - sequence is [i:memId] i; [P]=PersonalMemory / [O]=OuterMemory - """ - if not memories_all: - return "(none)", "(none)" - - lines_o = [] - lines_p = [] - for idx, m in enumerate(memories_all[:max_items], 1): - mid = _short_id(getattr(m, "id", "") or "") - mtype = getattr(getattr(m, "metadata", {}), "memory_type", None) or getattr( - m, "metadata", {} - ).get("memory_type", "") - tag = "O" if "Outer" in str(mtype) else "P" - txt = (getattr(m, "memory", "") or "").replace("\n", " ").strip() - if len(txt) > max_chars_each: - txt = txt[: max_chars_each - 1] + "โ€ฆ" - mid = mid or f"mem_{idx}" - if tag == "O": - lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}\n") - elif tag == "P": - lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}") - return "\n".join(lines_o), "\n".join(lines_p) - - -class MOSProduct(MOSCore): - """ - The MOSProduct class inherits from MOSCore and manages multiple users. - Each user has their own configuration and cube access, but shares the same model instances. - """ - - def __init__( - self, - default_config: MOSConfig | None = None, - max_user_instances: int = 1, - default_cube_config: GeneralMemCubeConfig | None = None, - online_bot=None, - error_bot=None, - ): - """ - Initialize MOSProduct with an optional default configuration. 
- - Args: - default_config (MOSConfig | None): Default configuration for new users - max_user_instances (int): Maximum number of user instances to keep in memory - default_cube_config (GeneralMemCubeConfig | None): Default cube configuration for loading cubes - online_bot: DingDing online_bot function or None if disabled - error_bot: DingDing error_bot function or None if disabled - """ - # Initialize with a root config for shared resources - if default_config is None: - # Create a minimal config for root user - root_config = MOSConfig( - user_id="root", - session_id="root_session", - chat_model=default_config.chat_model if default_config else None, - mem_reader=default_config.mem_reader if default_config else None, - enable_mem_scheduler=default_config.enable_mem_scheduler - if default_config - else False, - mem_scheduler=default_config.mem_scheduler if default_config else None, - ) - else: - root_config = default_config.model_copy(deep=True) - root_config.user_id = "root" - root_config.session_id = "root_session" - - # Create persistent user manager BEFORE calling parent constructor - persistent_user_manager_client = PersistentUserManagerFactory.from_config( - config_factory=root_config.user_manager - ) - - # Initialize parent MOSCore with root config and persistent user manager - super().__init__(root_config, user_manager=persistent_user_manager_client) - - # Product-specific attributes - self.default_config = default_config - self.default_cube_config = default_cube_config - self.max_user_instances = max_user_instances - self.online_bot = online_bot - self.error_bot = error_bot - - # User-specific data structures - self.user_configs: dict[str, MOSConfig] = {} - self.user_cube_access: dict[str, set[str]] = {} # user_id -> set of cube_ids - self.user_chat_histories: dict[str, dict] = {} - - # Note: self.user_manager is now the persistent user manager from parent class - # No need for separate global_user_manager as they are the same instance - - # Initialize 
tiktoken for streaming - try: - # Use gpt2 encoding which is more stable and widely compatible - self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") - logger.info("tokenizer initialized successfully for streaming") - except Exception as e: - logger.warning( - f"Failed to initialize tokenizer, will use character-based chunking: {e}" - ) - self.tokenizer = None - - # Restore user instances from persistent storage - self._restore_user_instances(default_cube_config=default_cube_config) - logger.info(f"User instances restored successfully, now user is {self.mem_cubes.keys()}") - - def _restore_user_instances( - self, default_cube_config: GeneralMemCubeConfig | None = None - ) -> None: - """Restore user instances from persistent storage after service restart. - - Args: - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. - """ - try: - # Get all user configurations from persistent storage - user_configs = self.user_manager.list_user_configs(self.max_user_instances) - - # Get the raw database records for sorting by updated_at - session = self.user_manager._get_session() - try: - from memos.mem_user.persistent_user_manager import UserConfig - - db_configs = session.query(UserConfig).limit(self.max_user_instances).all() - # Create a mapping of user_id to updated_at timestamp - updated_at_map = {config.user_id: config.updated_at for config in db_configs} - - # Sort by updated_at timestamp (most recent first) and limit by max_instances - sorted_configs = sorted( - user_configs.items(), key=lambda x: updated_at_map.get(x[0], ""), reverse=True - )[: self.max_user_instances] - finally: - session.close() - - for user_id, config in sorted_configs: - if user_id != "root": # Skip root user - try: - # Store user config and cube access - self.user_configs[user_id] = config - self._load_user_cube_access(user_id) - - # Pre-load all cubes for this user with default config - self._preload_user_cubes(user_id, 
default_cube_config) - - logger.info( - f"Restored user configuration and pre-loaded cubes for {user_id}" - ) - - except Exception as e: - logger.error(f"Failed to restore user configuration for {user_id}: {e}") - - except Exception as e: - logger.error(f"Error during user instance restoration: {e}") - - def _initialize_cube_from_default_config( - self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig - ) -> GeneralMemCube | None: - """ - Initialize a cube from default configuration when cube path doesn't exist. - - Args: - cube_id (str): The cube ID to initialize. - user_id (str): The user ID for the cube. - default_config (GeneralMemCubeConfig): The default configuration to use. - """ - cube_config = default_config.model_copy(deep=True) - # Safely modify the graph_db user_name if it exists - if cube_config.text_mem.config.graph_db.config: - cube_config.text_mem.config.graph_db.config.user_name = ( - f"memos{user_id.replace('-', '')}" - ) - mem_cube = GeneralMemCube(config=cube_config) - return mem_cube - - def _preload_user_cubes( - self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None - ) -> None: - """Pre-load all cubes for a user into memory. - - Args: - user_id (str): The user ID to pre-load cubes for. - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. 
- """ - try: - # Get user's accessible cubes from persistent storage - accessible_cubes = self.user_manager.get_user_cubes(user_id) - - for cube in accessible_cubes: - if cube.cube_id not in self.mem_cubes: - try: - if cube.cube_path and os.path.exists(cube.cube_path): - # Pre-load cube with all memory types and default config - self.register_mem_cube( - cube.cube_path, - cube.cube_id, - user_id, - memory_types=["act_mem"] - if self.config.enable_activation_memory - else [], - default_config=default_cube_config, - ) - logger.info(f"Pre-loaded cube {cube.cube_id} for user {user_id}") - else: - logger.warning( - f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, skipping pre-load" - ) - except Exception as e: - logger.error( - f"Failed to pre-load cube {cube.cube_id} for user {user_id}: {e}", - exc_info=True, - ) - - except Exception as e: - logger.error(f"Error pre-loading cubes for user {user_id}: {e}", exc_info=True) - - @timed - def _load_user_cubes( - self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None - ) -> None: - """Load all cubes for a user into memory. - - Args: - user_id (str): The user ID to load cubes for. - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. 
- """ - # Get user's accessible cubes from persistent storage - accessible_cubes = self.user_manager.get_user_cubes(user_id) - - for cube in accessible_cubes[:1]: - if cube.cube_id not in self.mem_cubes: - try: - if cube.cube_path and os.path.exists(cube.cube_path): - # Use MOSCore's register_mem_cube method directly with default config - # Only load act_mem since text_mem is stored in database - self.register_mem_cube( - cube.cube_path, - cube.cube_id, - user_id, - memory_types=["act_mem"], - default_config=default_cube_config, - ) - else: - logger.warning( - f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, now init by default config" - ) - cube_obj = self._initialize_cube_from_default_config( - cube_id=cube.cube_id, - user_id=user_id, - default_config=default_cube_config, - ) - if cube_obj: - self.register_mem_cube( - cube_obj, - cube.cube_id, - user_id, - memory_types=[], - ) - else: - raise ValueError( - f"Failed to initialize default cube {cube.cube_id} for user {user_id}" - ) - except Exception as e: - logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}") - logger.info(f"load user {user_id} cubes successfully") - - def _ensure_user_instance(self, user_id: str, max_instances: int | None = None) -> None: - """ - Ensure user configuration exists, creating it if necessary. 
- - Args: - user_id (str): The user ID - max_instances (int): Maximum instances to keep in memory (overrides class default) - """ - if user_id in self.user_configs: - return - - # Try to get config from persistent storage first - stored_config = self.user_manager.get_user_config(user_id) - if stored_config: - self.user_configs[user_id] = stored_config - self._load_user_cube_access(user_id) - else: - # Use default config - if not self.default_config: - raise ValueError(f"No configuration available for user {user_id}") - user_config = self.default_config.model_copy(deep=True) - user_config.user_id = user_id - user_config.session_id = f"{user_id}_session" - self.user_configs[user_id] = user_config - self._load_user_cube_access(user_id) - - # Apply LRU eviction if needed - max_instances = max_instances or self.max_user_instances - if len(self.user_configs) > max_instances: - # Remove least recently used instance (excluding root) - user_ids = [uid for uid in self.user_configs if uid != "root"] - if user_ids: - oldest_user_id = user_ids[0] - del self.user_configs[oldest_user_id] - if oldest_user_id in self.user_cube_access: - del self.user_cube_access[oldest_user_id] - logger.info(f"Removed least recently used user configuration: {oldest_user_id}") - - def _load_user_cube_access(self, user_id: str) -> None: - """Load user's cube access permissions.""" - try: - # Get user's accessible cubes from persistent storage - accessible_cubes = self.user_manager.get_user_cube_access(user_id) - self.user_cube_access[user_id] = set(accessible_cubes) - except Exception as e: - logger.warning(f"Failed to load cube access for user {user_id}: {e}") - self.user_cube_access[user_id] = set() - - def _get_user_config(self, user_id: str) -> MOSConfig: - """Get user configuration.""" - if user_id not in self.user_configs: - self._ensure_user_instance(user_id) - return self.user_configs[user_id] - - def _validate_user_cube_access(self, user_id: str, cube_id: str) -> None: - """Validate user has 
access to the cube.""" - if user_id not in self.user_cube_access: - self._load_user_cube_access(user_id) - - if cube_id not in self.user_cube_access.get(user_id, set()): - raise ValueError(f"User '{user_id}' does not have access to cube '{cube_id}'") - - def _validate_user_access(self, user_id: str, cube_id: str | None = None) -> None: - """Validate user access using MOSCore's built-in validation.""" - # Use MOSCore's built-in user validation - if cube_id: - self._validate_cube_access(user_id, cube_id) - else: - self._validate_user_exists(user_id) - - def _create_user_config(self, user_id: str, config: MOSConfig) -> MOSConfig: - """Create a new user configuration.""" - # Create a copy of config with the specific user_id - user_config = config.model_copy(deep=True) - user_config.user_id = user_id - user_config.session_id = f"{user_id}_session" - - # Save configuration to persistent storage - self.user_manager.save_user_config(user_id, user_config) - - return user_config - - def _get_or_create_user_config( - self, user_id: str, config: MOSConfig | None = None - ) -> MOSConfig: - """Get existing user config or create a new one.""" - if user_id in self.user_configs: - return self.user_configs[user_id] - - # Try to get config from persistent storage first - stored_config = self.user_manager.get_user_config(user_id) - if stored_config: - return self._create_user_config(user_id, stored_config) - - # Use provided config or default config - user_config = config or self.default_config - if not user_config: - raise ValueError(f"No configuration provided for user {user_id}") - - return self._create_user_config(user_id, user_config) - - def _build_system_prompt( - self, - memories_all: list[TextualMemoryItem], - base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", - ) -> str: - """ - Build custom system prompt for the user with memory references. - - Args: - user_id (str): The user ID. 
- memories (list[TextualMemoryItem]): The memories to build the system prompt. - - Returns: - str: The custom system prompt. - """ - # Build base prompt - # Add memory context if available - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt( - date=formatted_date, tone=tone, verbosity=verbosity, mode="base" - ) - mem_block_o, mem_block_p = _format_mem_block(memories_all) - mem_block = mem_block_o + "\n" + mem_block_p - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return ( - prefix - + sys_body - + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n" - + mem_block - ) - - def _build_base_system_prompt( - self, - base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", - mode: str = "enhance", - ) -> str: - """ - Build base system prompt without memory references. - """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode) - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return prefix + sys_body - - def _build_memory_context( - self, - memories_all: list[TextualMemoryItem], - mode: str = "enhance", - ) -> str: - """ - Build memory context to be included in user message. - """ - if not memories_all: - return "" - - mem_block_o, mem_block_p = _format_mem_block(memories_all) - - if mode == "enhance": - return ( - "# Memories\n## PersonalMemory (ordered)\n" - + mem_block_p - + "\n## OuterMemory (ordered)\n" - + mem_block_o - + "\n\n" - ) - else: - mem_block = mem_block_o + "\n" + mem_block_p - return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n" - - def _build_enhance_system_prompt( - self, - user_id: str, - memories_all: list[TextualMemoryItem], - tone: str = "friendly", - verbosity: str = "mid", - ) -> str: - """ - Build enhance prompt for the user with memory references. 
- [DEPRECATED] Use _build_base_system_prompt and _build_memory_context instead. - """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt( - date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance" - ) - mem_block_o, mem_block_p = _format_mem_block(memories_all) - return ( - sys_body - + "\n\n# Memories\n## PersonalMemory (ordered)\n" - + mem_block_p - + "\n## OuterMemory (ordered)\n" - + mem_block_o - ) - - def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]: - """ - Extract reference information from the response and return clean text. - - Args: - response (str): The complete response text. - - Returns: - tuple[str, list[dict]]: A tuple containing: - - clean_text: Text with reference markers removed - - references: List of reference information - """ - import re - - try: - references = [] - # Pattern to match [refid:memoriesID] - pattern = r"\[(\d+):([^\]]+)\]" - - matches = re.findall(pattern, response) - for ref_number, memory_id in matches: - references.append({"memory_id": memory_id, "reference_number": int(ref_number)}) - - # Remove all reference markers from the text to get clean text - clean_text = re.sub(pattern, "", response) - - # Clean up any extra whitespace that might be left after removing markers - clean_text = re.sub(r"\s+", " ", clean_text).strip() - - return clean_text, references - except Exception as e: - logger.error(f"Error extracting references from response: {e}", exc_info=True) - return response, [] - - def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict: - """ - get struct message from chat-history - # TODO: @xcy make this more general - """ - system_content = "" - memory_content = "" - chat_history = [] - - for item in chat_data: - role = item.get("role") - content = item.get("content", "") - if role == "system": - parts = content.split("# Memories", 1) - system_content = parts[0].strip() - if len(parts) > 1: - 
memory_content = "# Memories" + parts[1].strip() - elif role in ("user", "assistant"): - chat_history.append({"role": role, "content": content}) - - if chat_history and chat_history[-1]["role"] == "assistant": - if len(chat_history) >= 2 and chat_history[-2]["role"] == "user": - chat_history = chat_history[:-2] - else: - chat_history = chat_history[:-1] - - return {"system": system_content, "memory": memory_content, "chat_history": chat_history} - - def _chunk_response_with_tiktoken( - self, response: str, chunk_size: int = 5 - ) -> Generator[str, None, None]: - """ - Chunk response using tiktoken for proper token-based streaming. - - Args: - response (str): The response text to chunk. - chunk_size (int): Number of tokens per chunk. - - Yields: - str: Chunked text pieces. - """ - if self.tokenizer: - # Use tiktoken for proper token-based chunking - tokens = self.tokenizer.encode(response) - - for i in range(0, len(tokens), chunk_size): - token_chunk = tokens[i : i + chunk_size] - chunk_text = self.tokenizer.decode(token_chunk) - yield chunk_text - else: - # Fallback to character-based chunking - char_chunk_size = chunk_size * 4 # Approximate character to token ratio - for i in range(0, len(response), char_chunk_size): - yield response[i : i + char_chunk_size] - - def _send_message_to_scheduler( - self, - user_id: str, - mem_cube_id: str, - query: str, - label: str, - ): - """ - Send message to scheduler. 
- args: - user_id: str, - mem_cube_id: str, - query: str, - """ - - if self.enable_mem_scheduler and (self.mem_scheduler is not None): - message_item = ScheduleMessageItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - label=label, - content=query, - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) - - async def _post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions - """ - try: - logger.info( - f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}" - ) - logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}") - - clean_response, extracted_references = self._extract_references_from_response( - full_response - ) - struct_message = self._extract_struct_data_from_history(current_messages) - logger.info(f"Extracted {len(extracted_references)} references from response") - - # Send chat report notifications asynchronously - if self.online_bot: - logger.info("Online Bot Open!") - try: - from memos.memos_tools.notification_utils import ( - send_online_bot_notification_async, - ) - - # Prepare notification data - chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id} - chat_data.update( - { - "memory": struct_message["memory"], - "chat_history": struct_message["chat_history"], - "full_response": full_response, - } - ) - - system_data = { - "references": extracted_references, - "time_start": time_start, - "time_end": time_end, - "speed_improvement": speed_improvement, - } - - emoji_config = {"chat": "๐Ÿ’ฌ", "system_info": "๐Ÿ“Š"} - - await send_online_bot_notification_async( - online_bot=self.online_bot, - header_name="MemOS Chat Report", - sub_title_name="chat_with_references", - title_color="#00956D", - 
other_data1=chat_data, - other_data2=system_data, - emoji=emoji_config, - ) - except Exception as e: - logger.warning(f"Failed to send chat notification (async): {e}") - - self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL - ) - - self.add( - user_id=user_id, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - { - "role": "assistant", - "content": clean_response, # Store clean text without reference markers - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], - mem_cube_id=cube_id, - ) - - logger.info(f"Post-chat processing completed for user {user_id}") - - except Exception as e: - logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True) - - def _start_post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments - """ - logger.info("Start post_chat_processing...") - - def run_async_in_thread(): - """Running asynchronous tasks in a new thread""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - ) - finally: - loop.close() - except Exception as e: - logger.error( - f"Error in thread-based post-chat processing for user {user_id}: {e}", - exc_info=True, - ) - - try: - # Try to get the current event loop - asyncio.get_running_loop() - # Create task and store reference 
to prevent garbage collection - task = asyncio.create_task( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - ) - # Add exception handling for the background task - task.add_done_callback( - lambda t: ( - logger.error( - f"Error in background post-chat processing for user {user_id}: {t.exception()}", - exc_info=True, - ) - if t.exception() - else None - ) - ) - except RuntimeError: - # No event loop, run in a new thread with context propagation - thread = ContextThread( - target=run_async_in_thread, - name=f"PostChatProcessing-{user_id}", - # Set as a daemon thread to avoid blocking program exit - daemon=True, - ) - thread.start() - - def _filter_memories_by_threshold( - self, - memories: list[TextualMemoryItem], - threshold: float = 0.30, - min_num: int = 3, - memory_type: Literal["OuterMemory"] = "OuterMemory", - ) -> list[TextualMemoryItem]: - """ - Filter memories by threshold and type, at least min_num memories for Non-OuterMemory. 
- Args: - memories: list[TextualMemoryItem], - threshold: float, - min_num: int, - memory_type: Literal["OuterMemory"], - Returns: - list[TextualMemoryItem] - """ - sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True) - filtered_person = [m for m in memories if m.metadata.memory_type != memory_type] - filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type] - filtered = [] - per_memory_count = 0 - for m in sorted_memories: - if m.metadata.relativity >= threshold: - if m.metadata.memory_type != memory_type: - per_memory_count += 1 - filtered.append(m) - if len(filtered) < min_num: - filtered = filtered_person[:min_num] + filtered_outer[:min_num] - else: - if per_memory_count < min_num: - filtered += filtered_person[per_memory_count:min_num] - filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True) - return filtered_memory - - def register_mem_cube( - self, - mem_cube_name_or_path_or_object: str | GeneralMemCube, - mem_cube_id: str | None = None, - user_id: str | None = None, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, - default_config: GeneralMemCubeConfig | None = None, - ) -> None: - """ - Register a MemCube with the MOS. - - Args: - mem_cube_name_or_path_or_object (str | GeneralMemCube): The name, path, or GeneralMemCube object to register. - mem_cube_id (str, optional): The identifier for the MemCube. If not provided, a default ID is used. - user_id (str, optional): The user ID to register the cube for. - memory_types (list[str], optional): List of memory types to load. - If None, loads all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] - default_config (GeneralMemCubeConfig, optional): Default configuration for the cube. 
- """ - # Handle different input types - if isinstance(mem_cube_name_or_path_or_object, GeneralMemCube): - # Direct GeneralMemCube object provided - mem_cube = mem_cube_name_or_path_or_object - if mem_cube_id is None: - mem_cube_id = f"cube_{id(mem_cube)}" # Generate a unique ID - else: - # String path provided - mem_cube_name_or_path = mem_cube_name_or_path_or_object - if mem_cube_id is None: - mem_cube_id = mem_cube_name_or_path - - if mem_cube_id in self.mem_cubes: - logger.info(f"MemCube with ID {mem_cube_id} already in MOS, skip install.") - return - - # Create MemCube from path - time_start = time.time() - if os.path.exists(mem_cube_name_or_path): - mem_cube = GeneralMemCube.init_from_dir( - mem_cube_name_or_path, memory_types, default_config - ) - logger.info( - f"time register_mem_cube: init_from_dir time is: {time.time() - time_start}" - ) - else: - logger.warning( - f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo." - ) - mem_cube = GeneralMemCube.init_from_remote_repo( - mem_cube_name_or_path, memory_types=memory_types, default_config=default_config - ) - - # Register the MemCube - logger.info( - f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}" - ) - time_start = time.time() - self.mem_cubes[mem_cube_id] = mem_cube - time_end = time.time() - logger.info(f"time register_mem_cube: add mem_cube time is: {time_end - time_start}") - - def user_register( - self, - user_id: str, - user_name: str | None = None, - config: MOSConfig | None = None, - interests: str | None = None, - default_mem_cube: GeneralMemCube | None = None, - default_cube_config: GeneralMemCubeConfig | None = None, - mem_cube_id: str | None = None, - ) -> dict[str, str]: - """Register a new user with configuration and default cube. - - Args: - user_id (str): The user ID for registration. - user_name (str): The user name for registration. - config (MOSConfig | None, optional): User-specific configuration. 
Defaults to None. - interests (str | None, optional): User interests as string. Defaults to None. - default_mem_cube (GeneralMemCube | None, optional): Default memory cube. Defaults to None. - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. - - Returns: - dict[str, str]: Registration result with status and message. - """ - try: - # Use provided config or default config - user_config = config or self.default_config - if not user_config: - return { - "status": "error", - "message": "No configuration provided for user registration", - } - if not user_name: - user_name = user_id - - # Create user with configuration using persistent user manager - self.user_manager.create_user_with_config(user_id, user_config, UserRole.USER, user_id) - - # Create user configuration - user_config = self._create_user_config(user_id, user_config) - - # Create a default cube for the user using MOSCore's methods - default_cube_name = f"{user_name}_{user_id}_default_cube" - mem_cube_name_or_path = os.path.join(CUBE_PATH, default_cube_name) - default_cube_id = self.create_cube_for_user( - cube_name=default_cube_name, - owner_id=user_id, - cube_path=mem_cube_name_or_path, - cube_id=mem_cube_id, - ) - time_start = time.time() - if default_mem_cube: - try: - default_mem_cube.dump(mem_cube_name_or_path, memory_types=[]) - except Exception as e: - logger.error(f"Failed to dump default cube: {e}") - time_end = time.time() - logger.info(f"time user_register: dump default cube time is: {time_end - time_start}") - # Register the default cube with MOS - self.register_mem_cube( - mem_cube_name_or_path_or_object=default_mem_cube, - mem_cube_id=default_cube_id, - user_id=user_id, - memory_types=["act_mem"] if self.config.enable_activation_memory else [], - default_config=default_cube_config, # use default cube config - ) - - # Add interests to the default cube if provided - if interests: - self.add(memory_content=interests, 
def _get_further_suggestion(self, message: MessageList | None = None) -> list[str]:
    """Ask the chat LLM for follow-up query suggestions.

    Builds a prompt from the last two entries of ``message`` and expects the
    LLM to answer with JSON containing a ``"query"`` list.  Any failure
    (including ``message`` being None) is logged and yields an empty list.
    """
    try:
        recent = message[-2:]
        dialogue = "\n".join(f"{m['role']}: {m['content']}" for m in recent)
        prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue)
        raw = self.chat_llm.generate([{"role": "system", "content": prompt}])
        parsed = json.loads(clean_json_response(raw))
        return parsed["query"]
    except Exception as e:
        logger.error(f"Error getting further suggestion: {e}", exc_info=True)
        return []
- """ - if message: - further_suggestion = self._get_further_suggestion(message) - return further_suggestion - if language == "zh": - suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH - else: # English - suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN - text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[ - "text_mem" - ] - if text_mem_result: - memories = "\n".join([m.memory[:200] for m in text_mem_result[0]["memories"]]) - else: - memories = "" - message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}] - response = self.chat_llm.generate(message_list) - clean_response = clean_json_response(response) - response_json = json.loads(clean_response) - return response_json["query"] - - def chat( - self, - query: str, - user_id: str, - cube_id: str | None = None, - history: MessageList | None = None, - base_prompt: str | None = None, - internet_search: bool = False, - moscube: bool = False, - top_k: int = 10, - threshold: float = 0.5, - session_id: str | None = None, - ) -> str: - """ - Chat with LLM with memory references and complete response. 
- """ - self._load_user_cubes(user_id, self.default_cube_config) - time_start = time.time() - memories_result = super().search( - query, - user_id, - install_cube_ids=[cube_id] if cube_id else None, - top_k=top_k, - mode="fine", - internet_search=internet_search, - moscube=moscube, - session_id=session_id, - )["text_mem"] - - memories_list = [] - if memories_result: - memories_list = memories_result[0]["memories"] - memories_list = self._filter_memories_by_threshold(memories_list, threshold) - new_memories_list = [] - for m in memories_list: - m.metadata.embedding = [] - new_memories_list.append(m) - memories_list = new_memories_list - - system_prompt = super()._build_system_prompt(memories_list, base_prompt) - if history is not None: - # Use the provided history (even if it's empty) - history_info = history[-20:] - else: - # Fall back to internal chat_history - if user_id not in self.chat_history_manager: - self._register_chat_history(user_id, session_id) - history_info = self.chat_history_manager[user_id].chat_history[-20:] - current_messages = [ - {"role": "system", "content": system_prompt}, - *history_info, - {"role": "user", "content": query}, - ] - logger.info("Start to get final answer...") - response = self.chat_llm.generate(current_messages) - time_end = time.time() - self._start_post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=0.0, - current_messages=current_messages, - ) - return response, memories_list - - def chat_with_references( - self, - query: str, - user_id: str, - cube_id: str | None = None, - history: MessageList | None = None, - top_k: int = 20, - internet_search: bool = False, - moscube: bool = False, - session_id: str | None = None, - ) -> Generator[str, None, None]: - """ - Chat with LLM with memory references and streaming output. - - Args: - query (str): Query string. 
- user_id (str): User ID. - cube_id (str, optional): Custom cube ID for user. - history (MessageList, optional): Chat history. - - Returns: - Generator[str, None, None]: The response string generator with reference processing. - """ - - self._load_user_cubes(user_id, self.default_cube_config) - time_start = time.time() - memories_list = [] - yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" - memories_result = super().search( - query, - user_id, - install_cube_ids=[cube_id] if cube_id else None, - top_k=top_k, - mode="fine", - internet_search=internet_search, - moscube=moscube, - session_id=session_id, - )["text_mem"] - - yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" - search_time_end = time.time() - logger.info( - f"time chat: search text_mem time user_id: {user_id} time is: {search_time_end - time_start}" - ) - self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_TASK_LABEL - ) - if memories_result: - memories_list = memories_result[0]["memories"] - memories_list = self._filter_memories_by_threshold(memories_list) - - reference = prepare_reference_data(memories_list) - yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" - # Build custom system prompt with relevant memories) - system_prompt = self._build_enhance_system_prompt(user_id, memories_list) - # Get chat history - if user_id not in self.chat_history_manager: - self._register_chat_history(user_id, session_id) - - chat_history = self.chat_history_manager[user_id] - if history is not None: - chat_history.chat_history = history[-20:] - current_messages = [ - {"role": "system", "content": system_prompt}, - *chat_history.chat_history, - {"role": "user", "content": query}, - ] - logger.info( - f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}" - ) - yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n" - # Generate response with custom prompt - past_key_values = None - 
response_stream = None - if self.config.enable_activation_memory: - # Handle activation memory (copy MOSCore logic) - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube.act_mem and mem_cube_id == cube_id: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) - if past_key_values is not None: - logger.info("past_key_values is not None will apply to chat") - else: - logger.info("past_key_values is None will not apply to chat") - break - if self.config.chat_model.backend == "huggingface": - response_stream = self.chat_llm.generate_stream( - current_messages, past_key_values=past_key_values - ) - elif self.config.chat_model.backend == "vllm": - response_stream = self.chat_llm.generate_stream(current_messages) - else: - if self.config.chat_model.backend in ["huggingface", "vllm", "openai"]: - response_stream = self.chat_llm.generate_stream(current_messages) - else: - response_stream = self.chat_llm.generate(current_messages) - - time_end = time.time() - chat_time_end = time.time() - logger.info( - f"time chat: chat time user_id: {user_id} time is: {chat_time_end - search_time_end}" - ) - # Simulate streaming output with proper reference handling using tiktoken - - # Initialize buffer for streaming - buffer = "" - full_response = "" - token_count = 0 - # Use tiktoken for proper token-based chunking - if self.config.chat_model.backend not in ["huggingface", "vllm", "openai"]: - # For non-huggingface backends, we need to collect the full response first - full_response_text = "" - for chunk in response_stream: - if chunk in ["", ""]: - continue - full_response_text += chunk - response_stream = self._chunk_response_with_tiktoken(full_response_text, chunk_size=5) - for chunk in response_stream: - if chunk in ["", ""]: - continue - token_count += 1 - buffer += chunk - full_response += chunk - - # Process buffer to ensure complete reference tags - 
processed_chunk, remaining_buffer = process_streaming_references_complete(buffer) - - if processed_chunk: - chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" - yield chunk_data - buffer = remaining_buffer - - # Process any remaining buffer - if buffer: - processed_chunk, remaining_buffer = process_streaming_references_complete(buffer) - if processed_chunk: - chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" - yield chunk_data - - # set kvcache improve speed - speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1) - total_time = round(float(time_end - time_start), 1) - - yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n" - # get further suggestion - current_messages.append({"role": "assistant", "content": full_response}) - further_suggestion = self._get_further_suggestion(current_messages) - logger.info(f"further_suggestion: {further_suggestion}") - yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n" - yield f"data: {json.dumps({'type': 'end'})}\n\n" - - # Asynchronous processing of logs, notifications and memory additions - self._start_post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - - def get_all( - self, - user_id: str, - memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"], - mem_cube_ids: list[str] | None = None, - ) -> list[dict[str, Any]]: - """Get all memory items for a user. - - Args: - user_id (str): The ID of the user. - cube_id (str | None, optional): The ID of the cube. Defaults to None. - memory_type (Literal["text_mem", "act_mem", "param_mem"]): The type of memory to get. 
- - Returns: - list[dict[str, Any]]: A list of memory items with cube_id and memories structure. - """ - - # Load user cubes if not already loaded - self._load_user_cubes(user_id, self.default_cube_config) - time_start = time.time() - memory_list = super().get_all( - mem_cube_id=mem_cube_ids[0] if mem_cube_ids else None, user_id=user_id - )[memory_type] - get_all_time_end = time.time() - logger.info( - f"time get_all: get_all time user_id: {user_id} time is: {get_all_time_end - time_start}" - ) - reformat_memory_list = [] - if memory_type == "text_mem": - for memory in memory_list: - memories = remove_embedding_recursive(memory["memories"]) - custom_type_ratios = { - "WorkingMemory": 0.20, - "LongTermMemory": 0.40, - "UserMemory": 0.40, - } - tree_result, node_type_count = convert_graph_to_tree_forworkmem( - memories, target_node_count=200, type_ratios=custom_type_ratios - ) - # Ensure all node IDs are unique in the tree structure - tree_result = ensure_unique_tree_ids(tree_result) - memories_filtered = filter_nodes_by_tree_ids(tree_result, memories) - children = tree_result["children"] - children_sort = sort_children_by_memory_type(children) - tree_result["children"] = children_sort - memories_filtered["tree_structure"] = tree_result - reformat_memory_list.append( - { - "cube_id": memory["cube_id"], - "memories": [memories_filtered], - "memory_statistics": node_type_count, - } - ) - elif memory_type == "act_mem": - memories_list = [] - act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all() - if act_mem_params: - memories_data = act_mem_params[0].model_dump() - records = memories_data.get("records", []) - for record in records["text_memories"]: - memories_list.append( - { - "id": memories_data["id"], - "text": record, - "create_time": records["timestamp"], - "size": random.randint(1, 20), - "modify_times": 1, - } - ) - reformat_memory_list.append( - { - "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0], - "memories": memories_list, 
- } - ) - elif memory_type == "para_mem": - act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all() - logger.info(f"act_mem_params: {act_mem_params}") - reformat_memory_list.append( - { - "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0], - "memories": act_mem_params[0].model_dump(), - } - ) - make_format_time_end = time.time() - logger.info( - f"time get_all: make_format time user_id: {user_id} time is: {make_format_time_end - get_all_time_end}" - ) - return reformat_memory_list - - def _get_subgraph( - self, query: str, mem_cube_id: str, user_id: str | None = None, top_k: int = 5 - ) -> list[dict[str, Any]]: - result = {"para_mem": [], "act_mem": [], "text_mem": []} - if self.config.enable_textual_memory and self.mem_cubes[mem_cube_id].text_mem: - result["text_mem"].append( - { - "cube_id": mem_cube_id, - "memories": self.mem_cubes[mem_cube_id].text_mem.get_relevant_subgraph( - query, top_k=top_k - ), - } - ) - return result - - def get_subgraph( - self, - user_id: str, - query: str, - mem_cube_ids: list[str] | None = None, - top_k: int = 20, - ) -> list[dict[str, Any]]: - """Get all memory items for a user. - - Args: - user_id (str): The ID of the user. - cube_id (str | None, optional): The ID of the cube. Defaults to None. - mem_cube_ids (list[str], optional): The IDs of the cubes. Defaults to None. - - Returns: - list[dict[str, Any]]: A list of memory items with cube_id and memories structure. 
- """ - - # Load user cubes if not already loaded - self._load_user_cubes(user_id, self.default_cube_config) - memory_list = self._get_subgraph( - query=query, mem_cube_id=mem_cube_ids[0], user_id=user_id, top_k=top_k - )["text_mem"] - reformat_memory_list = [] - for memory in memory_list: - memories = remove_embedding_recursive(memory["memories"]) - custom_type_ratios = {"WorkingMemory": 0.20, "LongTermMemory": 0.40, "UserMemory": 0.4} - tree_result, node_type_count = convert_graph_to_tree_forworkmem( - memories, target_node_count=150, type_ratios=custom_type_ratios - ) - # Ensure all node IDs are unique in the tree structure - tree_result = ensure_unique_tree_ids(tree_result) - memories_filtered = filter_nodes_by_tree_ids(tree_result, memories) - children = tree_result["children"] - children_sort = sort_children_by_memory_type(children) - tree_result["children"] = children_sort - memories_filtered["tree_structure"] = tree_result - reformat_memory_list.append( - { - "cube_id": memory["cube_id"], - "memories": [memories_filtered], - "memory_statistics": node_type_count, - } - ) - - return reformat_memory_list - - def search( - self, - query: str, - user_id: str, - install_cube_ids: list[str] | None = None, - top_k: int = 10, - mode: Literal["fast", "fine"] = "fast", - session_id: str | None = None, - ): - """Search memories for a specific user.""" - - # Load user cubes if not already loaded - time_start = time.time() - self._load_user_cubes(user_id, self.default_cube_config) - load_user_cubes_time_end = time.time() - logger.info( - f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}" - ) - search_result = super().search( - query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id - ) - search_time_end = time.time() - logger.info( - f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}" - ) - text_memory_list = search_result["text_mem"] - 
reformat_memory_list = [] - for memory in text_memory_list: - memories_list = [] - for data in memory["memories"]: - memories = data.model_dump() - memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["embedding"] = [] - memories["metadata"]["sources"] = [] - memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["id"] = memories["id"] - memories["metadata"]["memory"] = memories["memory"] - memories_list.append(memories) - reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) - logger.info(f"search memory list is : {reformat_memory_list}") - search_result["text_mem"] = reformat_memory_list - - pref_memory_list = search_result["pref_mem"] - reformat_pref_memory_list = [] - for memory in pref_memory_list: - memories_list = [] - for data in memory["memories"]: - memories = data.model_dump() - memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["embedding"] = [] - memories["metadata"]["sources"] = [] - memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["id"] = memories["id"] - memories["metadata"]["memory"] = memories["memory"] - memories_list.append(memories) - reformat_pref_memory_list.append( - {"cube_id": memory["cube_id"], "memories": memories_list} - ) - search_result["pref_mem"] = reformat_pref_memory_list - time_end = time.time() - logger.info( - f"time search: total time for user_id: {user_id} time is: {time_end - time_start}" - ) - return search_result - - def add( - self, - user_id: str, - messages: MessageList | None = None, - memory_content: str | None = None, - doc_path: str | None = None, - mem_cube_id: str | None = None, - source: str | None = None, - user_profile: bool = False, - session_id: str | None = None, - task_id: str | None = None, # Add task_id parameter - ): - """Add memory for a specific user.""" - - # Load user cubes if not already loaded - self._load_user_cubes(user_id, 
self.default_cube_config) - result = super().add( - messages, - memory_content, - doc_path, - mem_cube_id, - user_id, - session_id=session_id, - task_id=task_id, - ) - if user_profile: - try: - user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0] - user_interests = user_interests.replace(",", " ") - user_profile_memories = self.mem_cubes[ - mem_cube_id - ].text_mem.internet_retriever.retrieve_from_internet(query=user_interests, top_k=5) - for memory in user_profile_memories: - self.mem_cubes[mem_cube_id].text_mem.add(memory) - except Exception as e: - logger.error( - f"Failed to retrieve user profile: {e}, memory_content: {memory_content}" - ) - - return result - - def list_users(self) -> list: - """List all registered users.""" - return self.user_manager.list_users() - - def get_user_info(self, user_id: str) -> dict: - """Get user information including accessible cubes.""" - # Use MOSCore's built-in user validation - # Validate user access - self._validate_user_access(user_id) - - result = super().get_user_info() - - return result - - def share_cube_with_user(self, cube_id: str, owner_user_id: str, target_user_id: str) -> bool: - """Share a cube with another user.""" - # Use MOSCore's built-in cube access validation - self._validate_cube_access(owner_user_id, cube_id) - - result = super().share_cube_with_user(cube_id, target_user_id) - - return result - - def clear_user_chat_history(self, user_id: str) -> None: - """Clear chat history for a specific user.""" - # Validate user access - self._validate_user_access(user_id) - - super().clear_messages(user_id) - - def update_user_config(self, user_id: str, config: MOSConfig) -> bool: - """Update user configuration. - - Args: - user_id (str): The user ID. - config (MOSConfig): The new configuration. - - Returns: - bool: True if successful, False otherwise. 
- """ - try: - # Save to persistent storage - success = self.user_manager.save_user_config(user_id, config) - if success: - # Update in-memory config - self.user_configs[user_id] = config - logger.info(f"Updated configuration for user {user_id}") - - return success - except Exception as e: - logger.error(f"Failed to update user config for {user_id}: {e}") - return False - - def get_user_config(self, user_id: str) -> MOSConfig | None: - """Get user configuration. - - Args: - user_id (str): The user ID. - - Returns: - MOSConfig | None: The user's configuration or None if not found. - """ - return self.user_manager.get_user_config(user_id) - - def get_active_user_count(self) -> int: - """Get the number of active user configurations in memory.""" - return len(self.user_configs) - - def get_user_instance_info(self) -> dict[str, Any]: - """Get information about user configurations in memory.""" - return { - "active_instances": len(self.user_configs), - "max_instances": self.max_user_instances, - "user_ids": list(self.user_configs.keys()), - "lru_order": list(self.user_configs.keys()), # OrderedDict maintains insertion order - } diff --git a/src/memos/mem_os/product_server.py b/src/memos/mem_os/product_server.py deleted file mode 100644 index 80aefea85..000000000 --- a/src/memos/mem_os/product_server.py +++ /dev/null @@ -1,457 +0,0 @@ -import asyncio -import time - -from datetime import datetime -from typing import Literal - -from memos.context.context import ContextThread -from memos.llms.base import BaseLLM -from memos.log import get_logger -from memos.mem_cube.navie import NaiveMemCube -from memos.mem_os.product import _format_mem_block -from memos.mem_reader.base import BaseMemReader -from memos.memories.textual.item import TextualMemoryItem -from memos.templates.mos_prompts import ( - get_memos_prompt, -) -from memos.types import MessageList - - -logger = get_logger(__name__) - - -class MOSServer: - def __init__( - self, - mem_reader: BaseMemReader | None = None, - 
def chat(
    self,
    query: str,
    user_id: str,
    cube_id: str | None = None,
    mem_cube: NaiveMemCube | None = None,
    history: MessageList | None = None,
    base_prompt: str | None = None,
    internet_search: bool = False,
    moscube: bool = False,
    top_k: int = 10,
    threshold: float = 0.5,
    session_id: str | None = None,
) -> tuple[str, list[TextualMemoryItem]]:
    """
    Chat with the LLM using memories retrieved from ``mem_cube``.

    Args:
        query: The user's message.
        user_id: ID of the chatting user.
        cube_id: Cube identifier, also used as the search ``user_name``.
        mem_cube: The memory cube to search and (asynchronously) write to.
        history: Prior chat messages; only the last 20 are used.
        base_prompt: Optional custom system-prompt template.
        internet_search: Allow internet-backed retrieval during search.
        moscube: Include the MemOS knowledge cube in the search.
        top_k: Number of memories to retrieve.
        threshold: Minimum relativity for a memory to be used.
        session_id: Chat session identifier.

    Returns:
        ``(response, memories_list)``: the generated answer and the memories
        (embeddings cleared) that informed it.  NOTE: the original signature
        advertised ``str``; the tuple is what has always been returned.
    """
    time_start = time.time()
    memories_result = mem_cube.text_mem.search(
        query=query,
        user_name=cube_id,
        top_k=top_k,
        mode="fine",
        manual_close_internet=not internet_search,
        moscube=moscube,
        info={
            "user_id": user_id,
            "session_id": session_id,
            "chat_history": history,
        },
    )

    memories_list = []
    if memories_result:
        memories_list = self._filter_memories_by_threshold(memories_result, threshold)
        # Drop embeddings: they are large and irrelevant to prompt building.
        for m in memories_list:
            m.metadata.embedding = []
    system_prompt = self._build_system_prompt(memories_list, base_prompt)

    history_info = []
    if history:
        history_info = history[-20:]
    current_messages = [
        {"role": "system", "content": system_prompt},
        *history_info,
        {"role": "user", "content": query},
    ]
    response = self.chat_llm.generate(current_messages)
    time_end = time.time()
    # Bookkeeping (logs, notifications, memory adds) happens off the hot path.
    self._start_post_chat_processing(
        user_id=user_id,
        cube_id=cube_id,
        session_id=session_id,
        query=query,
        full_response=response,
        system_prompt=system_prompt,
        time_start=time_start,
        time_end=time_end,
        speed_improvement=0.0,
        current_messages=current_messages,
        mem_cube=mem_cube,
        history=history,
    )
    return response, memories_list
None = None, - ) -> list[str]: - memories = self.mem_reader.get_memory( - [messages], - type="chat", - info={ - "user_id": user_id, - "session_id": session_id, - "chat_history": history, - }, - ) - flattened_memories = [mm for m in memories for mm in m] - mem_id_list: list[str] = mem_cube.text_mem.add( - flattened_memories, - user_name=cube_id, - ) - return mem_id_list - - def search( - self, - user_id: str, - cube_id: str, - session_id: str | None = None, - ) -> None: - NotImplementedError("Not implemented") - - def _filter_memories_by_threshold( - self, - memories: list[TextualMemoryItem], - threshold: float = 0.30, - min_num: int = 3, - memory_type: Literal["OuterMemory"] = "OuterMemory", - ) -> list[TextualMemoryItem]: - """ - Filter memories by threshold and type, at least min_num memories for Non-OuterMemory. - Args: - memories: list[TextualMemoryItem], - threshold: float, - min_num: int, - memory_type: Literal["OuterMemory"], - Returns: - list[TextualMemoryItem] - """ - sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True) - filtered_person = [m for m in memories if m.metadata.memory_type != memory_type] - filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type] - filtered = [] - per_memory_count = 0 - for m in sorted_memories: - if m.metadata.relativity >= threshold: - if m.metadata.memory_type != memory_type: - per_memory_count += 1 - filtered.append(m) - if len(filtered) < min_num: - filtered = filtered_person[:min_num] + filtered_outer[:min_num] - else: - if per_memory_count < min_num: - filtered += filtered_person[per_memory_count:min_num] - filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True) - return filtered_memory - - def _build_base_system_prompt( - self, - base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", - mode: str = "enhance", - ) -> str: - """ - Build base system prompt without memory references. 
- """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode) - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return prefix + sys_body - - def _build_system_prompt( - self, - memories: list[TextualMemoryItem] | list[str] | None = None, - base_prompt: str | None = None, - **kwargs, - ) -> str: - """Build system prompt with optional memories context.""" - if base_prompt is None: - base_prompt = ( - "You are a knowledgeable and helpful AI assistant. " - "You have access to conversation memories that help you provide more personalized responses. " - "Use the memories to understand the user's context, preferences, and past interactions. " - "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." - ) - - memory_context = "" - if memories: - memory_list = [] - for i, memory in enumerate(memories, 1): - if isinstance(memory, TextualMemoryItem): - text_memory = memory.memory - else: - if not isinstance(memory, str): - logger.error("Unexpected memory type.") - text_memory = memory - memory_list.append(f"{i}. {text_memory}") - memory_context = "\n".join(memory_list) - - if "{memories}" in base_prompt: - return base_prompt.format(memories=memory_context) - elif base_prompt and memories: - # For backward compatibility, append memories if no placeholder is found - memory_context_with_header = "\n\n## Memories:\n" + memory_context - return base_prompt + memory_context_with_header - return base_prompt - - def _build_memory_context( - self, - memories_all: list[TextualMemoryItem], - mode: str = "enhance", - ) -> str: - """ - Build memory context to be included in user message. 
- """ - if not memories_all: - return "" - - mem_block_o, mem_block_p = _format_mem_block(memories_all) - - if mode == "enhance": - return ( - "# Memories\n## PersonalMemory (ordered)\n" - + mem_block_p - + "\n## OuterMemory (ordered)\n" - + mem_block_o - + "\n\n" - ) - else: - mem_block = mem_block_o + "\n" + mem_block_p - return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n" - - def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]: - """ - Extract reference information from the response and return clean text. - - Args: - response (str): The complete response text. - - Returns: - tuple[str, list[dict]]: A tuple containing: - - clean_text: Text with reference markers removed - - references: List of reference information - """ - import re - - try: - references = [] - # Pattern to match [refid:memoriesID] - pattern = r"\[(\d+):([^\]]+)\]" - - matches = re.findall(pattern, response) - for ref_number, memory_id in matches: - references.append({"memory_id": memory_id, "reference_number": int(ref_number)}) - - # Remove all reference markers from the text to get clean text - clean_text = re.sub(pattern, "", response) - - # Clean up any extra whitespace that might be left after removing markers - clean_text = re.sub(r"\s+", " ", clean_text).strip() - - return clean_text, references - except Exception as e: - logger.error(f"Error extracting references from response: {e}", exc_info=True) - return response, [] - - async def _post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - mem_cube: NaiveMemCube | None = None, - session_id: str | None = None, - history: MessageList | None = None, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions - """ - try: - logger.info( - f"user_id: {user_id}, cube_id: {cube_id}, 
current_messages: {current_messages}" - ) - logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}") - - clean_response, extracted_references = self._extract_references_from_response( - full_response - ) - logger.info(f"Extracted {len(extracted_references)} references from response") - - # Send chat report notifications asynchronously - if self.online_bot: - try: - from memos.memos_tools.notification_utils import ( - send_online_bot_notification_async, - ) - - # Prepare notification data - chat_data = { - "query": query, - "user_id": user_id, - "cube_id": cube_id, - "system_prompt": system_prompt, - "full_response": full_response, - } - - system_data = { - "references": extracted_references, - "time_start": time_start, - "time_end": time_end, - "speed_improvement": speed_improvement, - } - - emoji_config = {"chat": "๐Ÿ’ฌ", "system_info": "๐Ÿ“Š"} - - await send_online_bot_notification_async( - online_bot=self.online_bot, - header_name="MemOS Chat Report", - sub_title_name="chat_with_references", - title_color="#00956D", - other_data1=chat_data, - other_data2=system_data, - emoji=emoji_config, - ) - except Exception as e: - logger.warning(f"Failed to send chat notification (async): {e}") - - self.add( - user_id=user_id, - cube_id=cube_id, - mem_cube=mem_cube, - session_id=session_id, - history=history, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - { - "role": "assistant", - "content": clean_response, # Store clean text without reference markers - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], - ) - - logger.info(f"Post-chat processing completed for user {user_id}") - - except Exception as e: - logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True) - - def _start_post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, 
- time_end: float, - speed_improvement: float, - current_messages: list, - mem_cube: NaiveMemCube | None = None, - session_id: str | None = None, - history: MessageList | None = None, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments - """ - - def run_async_in_thread(): - """Running asynchronous tasks in a new thread""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - mem_cube=mem_cube, - session_id=session_id, - history=history, - ) - ) - finally: - loop.close() - except Exception as e: - logger.error( - f"Error in thread-based post-chat processing for user {user_id}: {e}", - exc_info=True, - ) - - try: - # Try to get the current event loop - asyncio.get_running_loop() - # Create task and store reference to prevent garbage collection - task = asyncio.create_task( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - ) - # Add exception handling for the background task - task.add_done_callback( - lambda t: ( - logger.error( - f"Error in background post-chat processing for user {user_id}: {t.exception()}", - exc_info=True, - ) - if t.exception() - else None - ) - ) - except RuntimeError: - # No event loop, run in a new thread with context propagation - thread = ContextThread( - target=run_async_in_thread, - name=f"PostChatProcessing-{user_id}", - # Set as a daemon thread to avoid blocking program exit - daemon=True, - ) - thread.start() diff 
--git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index ec431c253..671190e6f 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -47,13 +47,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: graph_db_backend_map = { "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), "postgres": APIConfig.get_postgres_config(user_id=user_id), } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py index 3498f596a..4d122ca4e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py @@ -9,6 +9,7 @@ from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import ( InternetGoogleRetriever, ) +from memos.memories.textual.tree_text_memory.retrieve.tavilysearch import InternetTavilyRetriever from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever from memos.memos_tools.singleton import singleton_factory @@ -21,6 +22,7 @@ class InternetRetrieverFactory: "bing": InternetGoogleRetriever, # TODO: Implement 
BingRetriever "xinyu": XinyuSearchRetriever, "bocha": BochaAISearchRetriever, + "tavily": InternetTavilyRetriever, } @classmethod @@ -81,6 +83,14 @@ def from_config( reader=MemReaderFactory.from_config(config.reader), max_results=config.max_results, ) + elif backend == "tavily": + return retriever_class( + api_key=config.api_key, + embedder=embedder, + max_results=config.max_results, + search_depth=config.search_depth, + include_answer=config.include_answer, + ) else: raise ValueError(f"Unsupported backend: {backend}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py new file mode 100644 index 000000000..fd7180b01 --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py @@ -0,0 +1,327 @@ +"""Tavily Search API retriever for tree text memory.""" + +from concurrent.futures import as_completed +from datetime import datetime +from typing import Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.dependency import require_python_package +from memos.embedders.factory import OllamaEmbedder +from memos.log import get_logger +from memos.mem_reader.read_multi_modal import detect_lang +from memos.memories.textual.item import ( + SearchedTreeNodeTextualMemoryMetadata, + SourceMessage, + TextualMemoryItem, +) + + +logger = get_logger(__name__) + + +class InternetTavilyRetriever: + """Tavily retriever that converts search results into TextualMemoryItem objects""" + + @require_python_package( + import_name="tavily", + install_command="pip install tavily-python", + install_link="https://github.com/tavily-ai/tavily-python", + ) + def __init__( + self, + api_key: str, + embedder: OllamaEmbedder, + max_results: int = 10, + search_depth: str = "basic", + include_answer: bool = False, + ): + """ + Initialize Tavily Search retriever. 
+ + Args: + api_key: Tavily API key + embedder: Embedder instance for generating embeddings + max_results: Maximum number of search results to retrieve + search_depth: Search depth ('basic' or 'advanced') + include_answer: Whether to include an AI-generated answer + """ + from tavily import TavilyClient + + self.client = TavilyClient(api_key=api_key) + self.embedder = embedder + self.max_results = max_results + self.search_depth = search_depth + self.include_answer = include_answer + + import jieba.analyse + + self.zh_fast_keywords_extractor = jieba.analyse.TextRank() + + def _extract_tags(self, title: str, content: str, summary: str, parsed_goal=None) -> list[str]: + """ + Extract tags from title, content and summary. + + Args: + title: Article title + content: Article content + summary: Article summary + parsed_goal: Parsed task goal (optional) + + Returns: + List of extracted tags + """ + tags = [] + + tags.append("tavily_search") + tags.append("news") + + text = f"{title} {content} {summary}".lower() + + keywords = { + "economy": [ + "economy", + "GDP", + "growth", + "production", + "industry", + "investment", + "consumption", + "market", + "trade", + "finance", + ], + "politics": [ + "politics", + "government", + "policy", + "meeting", + "leader", + "election", + "parliament", + "ministry", + ], + "technology": [ + "technology", + "tech", + "innovation", + "digital", + "internet", + "AI", + "artificial intelligence", + "software", + "hardware", + ], + "sports": [ + "sports", + "game", + "athlete", + "olympic", + "championship", + "tournament", + "team", + "player", + ], + "culture": [ + "culture", + "education", + "art", + "history", + "literature", + "music", + "film", + "museum", + ], + "health": [ + "health", + "medical", + "pandemic", + "hospital", + "doctor", + "medicine", + "disease", + "treatment", + ], + "environment": [ + "environment", + "ecology", + "pollution", + "green", + "climate", + "sustainability", + "renewable", + ], + } + + for category, 
words in keywords.items(): + if any(word in text for word in words): + tags.append(category) + + if parsed_goal and hasattr(parsed_goal, "tags"): + tags.extend(parsed_goal.tags) + + return list(set(tags))[:15] + + def retrieve_from_internet( + self, query: str, top_k: int = 10, parsed_goal=None, info=None, mode="fast" + ) -> list[TextualMemoryItem]: + """ + Retrieve results from the internet using Tavily Search API. + + Args: + query: Search query + top_k: Number of results to retrieve + parsed_goal: Parsed task goal (optional) + info (dict): Metadata for memory consumption tracking + mode: Retrieval mode ('fast' or other) + + Returns: + List of TextualMemoryItem + """ + try: + response = self.client.search( + query=query, + max_results=min(top_k, self.max_results), + search_depth=self.search_depth, + include_answer=self.include_answer, + ) + search_results = response.get("results", []) + except Exception: + import traceback + + logger.error(f"Tavily search error: {traceback.format_exc()}") + return [] + + return self._convert_to_mem_items(search_results, query, parsed_goal, info, mode=mode) + + def _convert_to_mem_items( + self, search_results: list[dict], query: str, parsed_goal=None, info=None, mode="fast" + ): + """Convert Tavily search results into TextualMemoryItem objects.""" + memory_items = [] + if not info: + info = {"user_id": "", "session_id": ""} + + with ContextThreadPoolExecutor(max_workers=8) as executor: + futures = [ + executor.submit(self._process_result, r, query, parsed_goal, info, mode=mode) + for r in search_results + ] + for future in as_completed(futures): + try: + memory_items.extend(future.result()) + except Exception as e: + logger.error(f"Error processing Tavily search result: {e}") + + unique_memory_items = {item.memory: item for item in memory_items} + return list(unique_memory_items.values()) + + def _process_result( + self, result: dict, query: str, parsed_goal: str, info: dict[str, Any], mode="fast" + ) -> list[TextualMemoryItem]: 
+ """Process one Tavily search result into TextualMemoryItem.""" + title = result.get("title", "") + content = result.get("content", "") + url = result.get("url", "") + publish_time = result.get("published_date", "") + + if publish_time: + try: + publish_time = datetime.fromisoformat(publish_time.replace("Z", "+00:00")).strftime( + "%Y-%m-%d" + ) + except Exception: + publish_time = datetime.now().strftime("%Y-%m-%d") + else: + publish_time = datetime.now().strftime("%Y-%m-%d") + + summary = content[:500] if content else "" + + if mode == "fast": + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + lang = detect_lang(summary) + tags = ( + self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3] + if lang == "zh" + else self._extract_tags(title, content, summary)[:3] + ) + + return [ + TextualMemoryItem( + memory=( + f"[Outer internet view] Title: {title}\nNewsTime:" + f" {publish_time}\nSummary:" + f" {summary}\n" + ), + metadata=SearchedTreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type="OuterMemory", + status="activated", + type="fact", + source="web", + sources=[SourceMessage(type="web", url=url)] if url else [], + visibility="public", + info=info_, + background="", + confidence=0.99, + usage=[], + tags=tags, + key=title, + embedding=self.embedder.embed([content])[0], + internet_info={ + "title": title, + "url": url, + "site_name": "", + "site_icon": None, + "summary": summary, + }, + ), + ) + ] + else: + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + lang = detect_lang(summary) + tags = ( + self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3] + if lang == "zh" + else self._extract_tags(title, content, summary)[:3] + ) + + return [ + TextualMemoryItem( + memory=( + f"[Outer internet view] Title: {title}\nNewsTime:" + f" {publish_time}\nSummary:" + f" {summary}\n" + f"Content: {content}" + ), + 
metadata=SearchedTreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type="OuterMemory", + status="activated", + type="fact", + source="web", + sources=[SourceMessage(type="web", url=url)] if url else [], + visibility="public", + info=info_, + background="", + confidence=0.99, + usage=[], + tags=tags, + key=title, + embedding=self.embedder.embed([content])[0], + internet_info={ + "title": title, + "url": url, + "site_name": "", + "site_icon": None, + "summary": summary, + }, + ), + ) + ] diff --git a/tests/api/test_memory_handler_delete.py b/tests/api/test_memory_handler_delete.py new file mode 100644 index 000000000..7d60f946b --- /dev/null +++ b/tests/api/test_memory_handler_delete.py @@ -0,0 +1,142 @@ +from unittest.mock import Mock + +from memos.api.handlers.memory_handler import handle_delete_memories +from memos.api.product_models import DeleteMemoryRequest + + +def _build_naive_mem_cube() -> Mock: + naive_mem_cube = Mock() + naive_mem_cube.text_mem = Mock() + naive_mem_cube.pref_mem = Mock() + return naive_mem_cube + + +def test_delete_memories_quick_by_user_id(): + naive_mem_cube = _build_naive_mem_cube() + req = DeleteMemoryRequest(user_id="u_1") + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "success" + naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with( + writable_cube_ids=None, + filter={"and": [{"user_id": "u_1"}]}, + ) + naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with( + filter={"and": [{"user_id": "u_1"}]} + ) + + +def test_delete_memories_quick_by_conversation_alias(): + naive_mem_cube = _build_naive_mem_cube() + req = DeleteMemoryRequest(conversation_id="conv_1") + + assert req.session_id == "conv_1" + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "success" + naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with( + writable_cube_ids=None, + filter={"and": [{"session_id": "conv_1"}]}, + ) + 
naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with( + filter={"and": [{"session_id": "conv_1"}]} + ) + + +def test_delete_memories_filter_and_quick_conditions(): + naive_mem_cube = _build_naive_mem_cube() + req = DeleteMemoryRequest( + filter={"and": [{"memory_type": "WorkingMemory"}]}, + user_id="u_1", + session_id="s_1", + ) + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "success" + naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with( + writable_cube_ids=None, + filter={ + "and": [ + {"memory_type": "WorkingMemory"}, + {"user_id": "u_1", "session_id": "s_1"}, + ] + }, + ) + naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with( + filter={ + "and": [ + {"memory_type": "WorkingMemory"}, + {"user_id": "u_1", "session_id": "s_1"}, + ] + } + ) + + +def test_delete_memories_filter_or_with_distribution(): + naive_mem_cube = _build_naive_mem_cube() + req = DeleteMemoryRequest( + filter={"or": [{"memory_type": "WorkingMemory"}, {"memory_type": "UserMemory"}]}, + user_id="u_1", + ) + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "success" + naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with( + writable_cube_ids=None, + filter={ + "or": [ + {"memory_type": "WorkingMemory", "user_id": "u_1"}, + {"memory_type": "UserMemory", "user_id": "u_1"}, + ] + }, + ) + naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with( + filter={ + "or": [ + {"memory_type": "WorkingMemory", "user_id": "u_1"}, + {"memory_type": "UserMemory", "user_id": "u_1"}, + ] + } + ) + + +def test_delete_memories_reject_multiple_modes(): + naive_mem_cube = _build_naive_mem_cube() + req = DeleteMemoryRequest(memory_ids=["m_1"], user_id="u_1") + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "failure" + assert "Exactly one delete mode must be provided" in resp.message + naive_mem_cube.text_mem.delete_by_filter.assert_not_called() + 
naive_mem_cube.text_mem.delete_by_memory_ids.assert_not_called() + + +def test_delete_memories_reject_empty_filter(): + naive_mem_cube = _build_naive_mem_cube() + req = DeleteMemoryRequest(filter={}) + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "failure" + assert "filter cannot be empty" in resp.message + naive_mem_cube.text_mem.delete_by_filter.assert_not_called() + naive_mem_cube.pref_mem.delete_by_filter.assert_not_called() + + +def test_delete_memories_with_pref_mem_disabled(): + naive_mem_cube = _build_naive_mem_cube() + naive_mem_cube.pref_mem = None + req = DeleteMemoryRequest(user_id="u_1") + + resp = handle_delete_memories(req, naive_mem_cube) + + assert resp.data["status"] == "success" + naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with( + writable_cube_ids=None, + filter={"and": [{"user_id": "u_1"}]}, + ) diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py deleted file mode 100644 index 857b290c5..000000000 --- a/tests/api/test_product_router.py +++ /dev/null @@ -1,422 +0,0 @@ -""" -Unit tests for product_router input/output format validation. - -This module tests that the product_router endpoints correctly validate -input request formats and return properly formatted responses. -""" - -from unittest.mock import Mock, patch - -import pytest - -from fastapi.testclient import TestClient - -# Patch the MOS_PRODUCT_INSTANCE directly after import -# Patch MOS_PRODUCT_INSTANCE and MOSProduct so we can test the FastAPI router -# without initializing the full MemOS product stack. 
-import memos.api.routers.product_router as pr_module - - -_mock_mos_instance = Mock() -pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance -pr_module.get_mos_product_instance = lambda: _mock_mos_instance -with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): - from memos.api import product_api - - -@pytest.fixture(scope="module") -def mock_mos_product_instance(): - """Mock get_mos_product_instance for all tests.""" - # Ensure the mock is set - pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance - pr_module.get_mos_product_instance = lambda: _mock_mos_instance - yield product_api.app, _mock_mos_instance - - -@pytest.fixture -def client(mock_mos_product_instance): - """Create test client for product_api.""" - app, _ = mock_mos_product_instance - return TestClient(app) - - -@pytest.fixture -def mock_mos_product(mock_mos_product_instance): - """Get the mocked MOSProduct instance.""" - _, mock_instance = mock_mos_product_instance - # Ensure get_mos_product_instance returns this mock - import memos.api.routers.product_router as pr_module - - pr_module.get_mos_product_instance = lambda: mock_instance - pr_module.MOS_PRODUCT_INSTANCE = mock_instance - return mock_instance - - -@pytest.fixture(autouse=True) -def setup_mock_mos_product(mock_mos_product): - """Set up default return values for MOSProduct methods.""" - # Set up default return values for methods - mock_mos_product.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_mos_product.add.return_value = None - mock_mos_product.chat.return_value = ("test response", []) - mock_mos_product.chat_with_references.return_value = iter( - ['data: {"type": "content", "data": "test"}\n\n'] - ) - # Ensure get_all and get_subgraph return proper list format (MemoryResponse expects list) - default_memory_result = [{"cube_id": "test_cube", "memories": []}] - mock_mos_product.get_all.return_value = default_memory_result - mock_mos_product.get_subgraph.return_value = 
default_memory_result - mock_mos_product.get_suggestion_query.return_value = ["suggestion1", "suggestion2"] - # Ensure get_mos_product_instance returns the mock - import memos.api.routers.product_router as pr_module - - pr_module.get_mos_product_instance = lambda: mock_mos_product - - -class TestProductRouterSearch: - """Test /search endpoint input/output format.""" - - def test_search_valid_input_output(self, mock_mos_product, client): - """Test search endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "query": "test query", - "mem_cube_id": "test_cube", - "top_k": 10, - } - - response = client.post("/product/search", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert isinstance(data["data"], dict) - - # Verify MOSProduct.search was called with correct parameters - mock_mos_product.search.assert_called_once() - call_kwargs = mock_mos_product.search.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["query"] == "test query" - - def test_search_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test search endpoint with missing required field.""" - request_data = { - "query": "test query", - } - - response = client.post("/product/search", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_search_response_format(self, mock_mos_product, client): - """Test search endpoint returns SearchResponse format.""" - mock_mos_product.search.return_value = { - "text_mem": [{"cube_id": "test_cube", "memories": []}], - "act_mem": [], - "para_mem": [], - } - - request_data = { - "user_id": "test_user", - "query": "test query", - } - - response = client.post("/product/search", json=request_data) - - assert response.status_code == 200 - data = 
response.json() - assert data["message"] == "Search completed successfully" - assert isinstance(data["data"], dict) - assert "text_mem" in data["data"] - - -class TestProductRouterAdd: - """Test /add endpoint input/output format.""" - - def test_add_valid_input_output(self, mock_mos_product, client): - """Test add endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "memory_content": "test memory content", - "mem_cube_id": "test_cube", - } - - response = client.post("/product/add", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert data["data"] is None # SimpleResponse has None data - - # Verify MOSProduct.add was called with correct parameters - mock_mos_product.add.assert_called_once() - call_kwargs = mock_mos_product.add.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["memory_content"] == "test memory content" - - def test_add_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test add endpoint with missing required field.""" - request_data = { - "memory_content": "test memory content", - } - - response = client.post("/product/add", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_add_response_format(self, mock_mos_product, client): - """Test add endpoint returns SimpleResponse format.""" - request_data = { - "user_id": "test_user", - "memory_content": "test memory content", - } - - response = client.post("/product/add", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Memory created successfully" - assert data["data"] is None - - -class TestProductRouterChatComplete: - """Test /chat/complete endpoint input/output format.""" - - def 
test_chat_complete_valid_input_output(self, mock_mos_product, client): - """Test chat/complete endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "query": "test query", - "mem_cube_id": "test_cube", - } - - response = client.post("/product/chat/complete", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "message" in data - assert "data" in data - assert isinstance(data["data"], dict) - assert "response" in data["data"] - assert "references" in data["data"] - - # Verify MOSProduct.chat was called with correct parameters - mock_mos_product.chat.assert_called_once() - call_kwargs = mock_mos_product.chat.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["query"] == "test query" - - def test_chat_complete_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test chat/complete endpoint with missing required field.""" - request_data = { - "query": "test query", - } - - response = client.post("/product/chat/complete", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_chat_complete_response_format(self, mock_mos_product, client): - """Test chat/complete endpoint returns correct format.""" - mock_mos_product.chat.return_value = ("test response", [{"id": "ref1"}]) - - request_data = { - "user_id": "test_user", - "query": "test query", - } - - response = client.post("/product/chat/complete", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Chat completed successfully" - assert isinstance(data["data"]["response"], str) - assert isinstance(data["data"]["references"], list) - - -class TestProductRouterChat: - """Test /chat endpoint input/output format (SSE stream).""" - - def test_chat_valid_input_output(self, mock_mos_product, client): - """Test chat endpoint with valid input returns 
SSE stream.""" - request_data = { - "user_id": "test_user", - "query": "test query", - "mem_cube_id": "test_cube", - } - - response = client.post("/product/chat", json=request_data) - - assert response.status_code == 200 - assert "text/event-stream" in response.headers["content-type"] - - # Verify MOSProduct.chat_with_references was called - mock_mos_product.chat_with_references.assert_called_once() - call_kwargs = mock_mos_product.chat_with_references.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["query"] == "test query" - - def test_chat_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test chat endpoint with missing required field.""" - request_data = { - "query": "test query", - } - - response = client.post("/product/chat", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - -class TestProductRouterSuggestions: - """Test /suggestions endpoint input/output format.""" - - def test_suggestions_valid_input_output(self, mock_mos_product, client): - """Test suggestions endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "mem_cube_id": "test_cube", - "language": "zh", - } - - response = client.post("/product/suggestions", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert isinstance(data["data"], dict) - assert "query" in data["data"] - - # Verify MOSProduct.get_suggestion_query was called - mock_mos_product.get_suggestion_query.assert_called_once() - call_kwargs = mock_mos_product.get_suggestion_query.call_args[1] - assert call_kwargs["user_id"] == "test_user" - - def test_suggestions_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test suggestions endpoint with missing required field.""" - request_data = { - 
"mem_cube_id": "test_cube", - } - - response = client.post("/product/suggestions", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_suggestions_response_format(self, mock_mos_product, client): - """Test suggestions endpoint returns SuggestionResponse format.""" - mock_mos_product.get_suggestion_query.return_value = [ - "suggestion1", - "suggestion2", - "suggestion3", - ] - - request_data = { - "user_id": "test_user", - "mem_cube_id": "test_cube", - "language": "en", - } - - response = client.post("/product/suggestions", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Suggestions retrieved successfully" - assert isinstance(data["data"], dict) - assert isinstance(data["data"]["query"], list) - - -class TestProductRouterGetAll: - """Test /get_all endpoint input/output format.""" - - def test_get_all_valid_input_output(self, mock_mos_product, client): - """Test get_all endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "memory_type": "text_mem", - } - - response = client.post("/product/get_all", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert isinstance(data["data"], list) - - # Verify MOSProduct.get_all was called - mock_mos_product.get_all.assert_called_once() - call_kwargs = mock_mos_product.get_all.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["memory_type"] == "text_mem" - - def test_get_all_with_search_query(self, mock_mos_product, client): - """Test get_all endpoint with search_query uses get_subgraph.""" - # Reset mock call counts - mock_mos_product.get_all.reset_mock() - mock_mos_product.get_subgraph.reset_mock() - - request_data = { - "user_id": "test_user", - 
"memory_type": "text_mem", - "search_query": "test query", - } - - response = client.post("/product/get_all", json=request_data) - - assert response.status_code == 200 - # Verify get_subgraph was called instead of get_all - mock_mos_product.get_subgraph.assert_called_once() - mock_mos_product.get_all.assert_not_called() - - def test_get_all_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test get_all endpoint with missing required field.""" - request_data = { - "memory_type": "text_mem", - } - - response = client.post("/product/get_all", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_get_all_response_format(self, mock_mos_product, client): - """Test get_all endpoint returns MemoryResponse format.""" - mock_mos_product.get_all.return_value = [{"cube_id": "test_cube", "memories": []}] - - request_data = { - "user_id": "test_user", - "memory_type": "text_mem", - } - - response = client.post("/product/get_all", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Memories retrieved successfully" - assert isinstance(data["data"], list) - assert len(data["data"]) > 0 diff --git a/tests/api/test_start_api.py b/tests/api/test_start_api.py deleted file mode 100644 index e1ffcd74b..000000000 --- a/tests/api/test_start_api.py +++ /dev/null @@ -1,401 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from fastapi.testclient import TestClient - -from memos.api.start_api import app -from memos.mem_user.user_manager import UserRole - - -client = TestClient(app) - -# Mock data -MOCK_MESSAGE = {"role": "user", "content": "test message"} -MOCK_MEMORY_CREATE = { - "messages": [MOCK_MESSAGE], - "mem_cube_id": "test_cube", - "user_id": "test_user", -} -MOCK_MEMORY_CONTENT = { - "memory_content": "test memory content", - "mem_cube_id": "test_cube", - "user_id": "test_user", -} -MOCK_DOC_PATH = {"doc_path": "/path/to/doc", 
"mem_cube_id": "test_cube", "user_id": "test_user"} -MOCK_SEARCH_REQUEST = { - "query": "test query", - "user_id": "test_user", - "install_cube_ids": ["test_cube"], -} -MOCK_MEMCUBE_REGISTER = { - "mem_cube_name_or_path": "test_cube_path", - "mem_cube_id": "test_cube", - "user_id": "test_user", -} -MOCK_CHAT_REQUEST = {"query": "test chat query", "user_id": "test_user"} -MOCK_USER_CREATE = {"user_id": "test_user", "user_name": "Test User", "role": "USER"} -MOCK_CUBE_SHARE = {"target_user_id": "target_user"} -MOCK_CONFIG = { - "user_id": "test_user", - "session_id": "test_session", - "enable_textual_memory": True, - "enable_activation_memory": False, - "top_k": 5, - "chat_model": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-3.5-turbo", - "api_key": "test_key", - "temperature": 0.7, - "api_base": "https://api.openai.com/v1", - }, - }, -} - - -@pytest.fixture -def mock_mos(): - """Mock MOS instance for testing.""" - with patch("memos.api.start_api.get_mos_instance") as mock_get_mos: - # Create a mock MOS instance - mock_instance = Mock() - - # Set up default return values for methods - mock_instance.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_instance.get_all.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_instance.get.return_value = {"memory": "test memory"} - mock_instance.chat.return_value = "test response" - mock_instance.list_users.return_value = [] - mock_instance.get_user_info.return_value = { - "user_id": "test_user", - "user_name": "Test User", - "role": "user", - "accessible_cubes": [], - } - mock_instance.create_user.return_value = "test_user" - mock_instance.share_cube_with_user.return_value = True - - # Configure the mock to return our mock instance - mock_get_mos.return_value = mock_instance - - yield mock_instance - - -def test_configure_error(mock_mos): - """Test configuration endpoint with error.""" - with patch("memos.api.start_api.MOS_INSTANCE", None): - response = 
client.post("/configure", json={}) - assert response.status_code == 422 # FastAPI validation error - - -def test_create_user(mock_mos): - """Test user creation endpoint.""" - response = client.post("/users", json=MOCK_USER_CREATE) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "User created successfully", - "data": {"user_id": "test_user"}, - } - mock_mos.create_user.assert_called_once_with( - user_id="test_user", role=UserRole.USER, user_name="Test User" - ) - - -def test_create_user_validation_error(mock_mos): - """Test user creation with validation error.""" - mock_mos.create_user.side_effect = ValueError("Invalid user data") - response = client.post("/users", json=MOCK_USER_CREATE) - assert response.status_code == 400 - assert "Invalid user data" in response.json()["message"] - - -def test_list_users(mock_mos): - """Test list users endpoint.""" - # Set up mock to return the expected data structure - mock_users = [ - { - "user_id": "test_user", - "user_name": "Test User", - "role": "user", - "created_at": "2023-01-01T00:00:00", - "is_active": True, - } - ] - mock_mos.list_users.return_value = mock_users - - response = client.get("/users") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Users retrieved successfully", - "data": mock_users, - } - mock_mos.list_users.assert_called_once() - - -def test_get_user_info(mock_mos): - """Test get user info endpoint.""" - # Set up mock to return the expected data structure - mock_user_info = { - "user_id": "test_user", - "user_name": "Test User", - "role": "user", - "created_at": "2023-01-01T00:00:00", - "accessible_cubes": [], - } - mock_mos.get_user_info.return_value = mock_user_info - - response = client.get("/users/me") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "User info retrieved successfully", - "data": mock_user_info, - } - mock_mos.get_user_info.assert_called_once() - 
- -def test_register_mem_cube(mock_mos): - """Test MemCube registration endpoint.""" - response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "MemCube registered successfully", - "data": None, - } - mock_mos.register_mem_cube.assert_called_once_with( - mem_cube_name_or_path="test_cube_path", mem_cube_id="test_cube", user_id="test_user" - ) - - -def test_register_mem_cube_validation_error(mock_mos): - """Test MemCube registration with validation error.""" - mock_mos.register_mem_cube.side_effect = ValueError("Invalid MemCube") - response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER) - assert response.status_code == 400 - assert "Invalid MemCube" in response.json()["message"] - - -def test_unregister_mem_cube(mock_mos): - """Test MemCube unregistration endpoint.""" - response = client.delete("/mem_cubes/test_cube?user_id=test_user") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "MemCube unregistered successfully", - "data": None, - } - mock_mos.unregister_mem_cube.assert_called_once_with( - mem_cube_id="test_cube", user_id="test_user" - ) - - -def test_unregister_nonexistent_mem_cube(mock_mos): - """Test unregistering a non-existent MemCube.""" - mock_mos.unregister_mem_cube.side_effect = ValueError("MemCube not found") - response = client.delete("/mem_cubes/nonexistent_cube") - assert response.status_code == 400 - assert "MemCube not found" in response.json()["message"] - - -def test_share_cube(mock_mos): - """Test cube sharing endpoint.""" - response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE) - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Cube shared successfully", "data": None} - mock_mos.share_cube_with_user.assert_called_once_with("test_cube", "target_user") - - -def test_share_cube_failure(mock_mos): - """Test cube sharing failure.""" 
- mock_mos.share_cube_with_user.return_value = False - response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE) - assert response.status_code == 400 - assert "Failed to share cube" in response.json()["message"] - - -@pytest.mark.parametrize( - "memory_create,expected_calls", - [ - (MOCK_MEMORY_CREATE, {"messages": [MOCK_MESSAGE]}), - (MOCK_MEMORY_CONTENT, {"memory_content": "test memory content"}), - (MOCK_DOC_PATH, {"doc_path": "/path/to/doc"}), - ], -) -def test_add_memory(mock_mos, memory_create, expected_calls): - """Test adding memories with different types of content.""" - response = client.post("/memories", json=memory_create) - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Memories added successfully", "data": None} - mock_mos.add.assert_called_once() - - -def test_add_memory_validation_error(mock_mos): - """Test adding memory with validation error.""" - response = client.post("/memories", json={}) - assert response.status_code == 400 - assert "must be provided" in response.json()["message"] - - -def test_get_all_memories(mock_mos): - """Test get all memories endpoint.""" - mock_results = { - "text_mem": [{"cube_id": "test_cube", "memories": []}], - "act_mem": [], - "para_mem": [], - } - mock_mos.get_all.return_value = mock_results - - response = client.get("/memories") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Memories retrieved successfully", - "data": mock_results, - } - mock_mos.get_all.assert_called_once_with(mem_cube_id=None, user_id=None) - - -def test_get_memory(mock_mos): - """Test get specific memory endpoint.""" - mock_memory = {"memory": "test memory content"} - mock_mos.get.return_value = mock_memory - - response = client.get("/memories/test_cube/test_memory") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Memory retrieved successfully", - "data": mock_memory, - } - 
mock_mos.get.assert_called_once_with( - mem_cube_id="test_cube", memory_id="test_memory", user_id=None - ) - - -def test_get_nonexistent_memory(mock_mos): - """Test getting a non-existent memory.""" - mock_mos.get.side_effect = ValueError("Memory not found") - response = client.get("/memories/test_cube/nonexistent_memory") - assert response.status_code == 400 - assert "Memory not found" in response.json()["message"] - - -def test_search_memories(mock_mos): - """Test search memories endpoint.""" - # Mock the search method to return a proper result structure - mock_results = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_mos.search.return_value = mock_results - - # Ensure the search request has all required fields - search_request = { - "query": "test query", - "user_id": "test_user", - "install_cube_ids": ["test_cube"], - } - - response = client.post("/search", json=search_request) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Search completed successfully", - "data": mock_results, - } - mock_mos.search.assert_called_once_with( - query="test query", user_id="test_user", install_cube_ids=["test_cube"] - ) - - -def test_update_memory(mock_mos): - """Test updating a memory endpoint.""" - update_data = {"content": "updated content"} - response = client.put("/memories/test_cube/test_memory?user_id=test_user", json=update_data) - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Memory updated successfully", "data": None} - mock_mos.update.assert_called_once_with( - mem_cube_id="test_cube", - memory_id="test_memory", - text_memory_item=update_data, - user_id="test_user", - ) - - -def test_update_nonexistent_memory(mock_mos): - """Test updating a non-existent memory.""" - mock_mos.update.side_effect = ValueError("Memory not found") - response = client.put("/memories/test_cube/nonexistent_memory", json={}) - assert response.status_code == 400 - assert "Memory not found" in 
response.json()["message"] - - -def test_delete_memory(mock_mos): - """Test deleting a memory endpoint.""" - response = client.delete("/memories/test_cube/test_memory?user_id=test_user") - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Memory deleted successfully", "data": None} - mock_mos.delete.assert_called_once_with( - mem_cube_id="test_cube", memory_id="test_memory", user_id="test_user" - ) - - -def test_delete_nonexistent_memory(mock_mos): - """Test deleting a non-existent memory.""" - mock_mos.delete.side_effect = ValueError("Memory not found") - response = client.delete("/memories/test_cube/nonexistent_memory") - assert response.status_code == 400 - assert "Memory not found" in response.json()["message"] - - -def test_delete_all_memories(mock_mos): - """Test deleting all memories endpoint.""" - response = client.delete("/memories/test_cube?user_id=test_user") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "All memories deleted successfully", - "data": None, - } - mock_mos.delete_all.assert_called_once_with(mem_cube_id="test_cube", user_id="test_user") - - -def test_delete_all_nonexistent_memories(mock_mos): - """Test deleting all memories from non-existent MemCube.""" - mock_mos.delete_all.side_effect = ValueError("MemCube not found") - response = client.delete("/memories/nonexistent_cube") - assert response.status_code == 400 - assert "MemCube not found" in response.json()["message"] - - -def test_chat(mock_mos): - """Test chat endpoint.""" - response = client.post("/chat", json=MOCK_CHAT_REQUEST) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Chat response generated", - "data": "test response", - } - mock_mos.chat.assert_called_once_with(query="test chat query", user_id="test_user") - - -def test_chat_without_user_id(mock_mos): - """Test chat endpoint without user_id.""" - chat_request = {"query": "test chat query"} - 
response = client.post("/chat", json=chat_request) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Chat response generated", - "data": "test response", - } - mock_mos.chat.assert_called_once_with(query="test chat query", user_id=None) - - -def test_home_redirect(): - """Test home endpoint redirects to docs.""" - response = client.get("/", follow_redirects=False) - assert response.status_code == 307 - assert response.headers["location"] == "/docs" diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py index f3d4549b5..d1507e389 100644 --- a/tests/configs/test_llm.py +++ b/tests/configs/test_llm.py @@ -2,6 +2,7 @@ BaseLLMConfig, HFLLMConfig, LLMConfigFactory, + MinimaxLLMConfig, OllamaLLMConfig, OpenAILLMConfig, ) @@ -145,6 +146,42 @@ def test_hf_llm_config(): check_config_instantiation_invalid(HFLLMConfig) +def test_minimax_llm_config(): + check_config_base_class( + MinimaxLLMConfig, + required_fields=["model_name_or_path", "api_key"], + optional_fields=[ + "temperature", + "max_tokens", + "top_p", + "top_k", + "api_base", + "remove_think_prefix", + "extra_body", + "default_headers", + "backup_client", + "backup_api_key", + "backup_api_base", + "backup_model_name_or_path", + "backup_headers", + ], + ) + + check_config_instantiation_valid( + MinimaxLLMConfig, + { + "model_name_or_path": "MiniMax-M2.7", + "api_key": "test-key", + "api_base": "https://api.minimax.io/v1", + "temperature": 0.7, + "max_tokens": 1024, + "top_p": 0.9, + }, + ) + + check_config_instantiation_invalid(MinimaxLLMConfig) + + def test_llm_config_factory(): check_config_factory_class( LLMConfigFactory, diff --git a/tests/graph_dbs/graph_dbs.py b/tests/graph_dbs/graph_dbs.py index 2cc35a0ad..c834f65a9 100644 --- a/tests/graph_dbs/graph_dbs.py +++ b/tests/graph_dbs/graph_dbs.py @@ -105,3 +105,26 @@ def test_get_memory_count(graph_db): session_mock.run.return_value.single.return_value = {"count": 42} count = 
graph_db.get_memory_count("WorkingMemory") assert count == 42 + + +def test_add_node_sanitizes_nested_metadata(graph_db): + session_mock = graph_db.driver.session.return_value.__enter__.return_value + node_id = str(uuid.uuid4()) + memory = "skill memory" + metadata = { + "memory_type": "SkillMemory", + "embedding": [0.1, 0.2, 0.3], + "tags": ["skill"], + "scripts": {"run.py": "print(1)"}, + "others": {"README.md": "# demo"}, + "info": {"nested": {"x": 1}, "arr_obj": [{"a": 1}]}, + } + + graph_db.add_node(node_id, memory, metadata) + + _, kwargs = session_mock.run.call_args + sanitized = kwargs["metadata"] + assert isinstance(sanitized["scripts"], str) + assert isinstance(sanitized["others"], str) + assert isinstance(sanitized["nested"], str) + assert sanitized["arr_obj"] == ['{"a": 1}'] diff --git a/tests/graph_dbs/test_neo4j_vector_search.py b/tests/graph_dbs/test_neo4j_vector_search.py new file mode 100644 index 000000000..3ed0b0587 --- /dev/null +++ b/tests/graph_dbs/test_neo4j_vector_search.py @@ -0,0 +1,425 @@ +""" +Tests for Neo4j vector search pre-filtering and related regressions. + +- Unit tests: verify query structure (pre-filter vs ANN paths) using mocks +- Integration tests: verify real search behavior with multi-user data (requires Neo4j 5.18+) + +The pre-filter approach (Neo4j 5.18+): + When WHERE filters are present (scope, status, user_name, etc.), the query uses + MATCH + WHERE to narrow candidates first, then vector.similarity.cosine() + computes similarity only on the filtered set. This avoids the post-filter + problem entirely โ€” no nodes are lost due to global top-k truncation. + + When no filters are present, the ANN vector index (db.index.vector.queryNodes) + is used for maximum efficiency. 
+""" + +import math +import os +import uuid + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.graph_db import Neo4jGraphDBConfig + + +# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Fixtures for unit tests (mocked Neo4j driver) +# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +@pytest.fixture +def shared_db_config(): + """Shared-database multi-tenant config (use_multi_db=False).""" + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_db", + auto_create=False, + use_multi_db=False, + user_name="default_user", + embedding_dimension=3, + ) + + +@pytest.fixture +def multi_db_config(): + """Multi-database config โ€” no user_name filter in queries.""" + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_db", + auto_create=False, + use_multi_db=True, + embedding_dimension=3, + ) + + +@pytest.fixture +def shared_neo4j_db(shared_db_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(shared_db_config) + db.driver = mock_driver + yield db + + +@pytest.fixture +def multi_neo4j_db(multi_db_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(multi_db_config) + db.driver = mock_driver + yield db + + +# 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Unit tests: pre-filter vs ANN query paths +# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestVectorSearchPreFilter: + """Verify pre-filter path uses MATCH + vector.similarity.cosine() + and ANN path uses db.index.vector.queryNodes.""" + + def test_prefilter_with_scope(self, shared_neo4j_db): + """With scope filter, query should use MATCH + cosine similarity, not queryNodes.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + scope="LongTermMemory", + ) + + query = session_mock.run.call_args[0][0] + assert "MATCH (node:Memory)" in query + assert "vector.similarity.cosine(node.embedding, $embedding)" in query + assert "queryNodes" not in query + + def test_prefilter_with_all_filters(self, shared_neo4j_db): + """With scope + status + user_name, all filters appear in WHERE before similarity.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=10, + scope="LongTermMemory", + status="activated", + user_name="some_user", + ) + + query = session_mock.run.call_args[0][0] + assert "MATCH (node:Memory)" in query + assert "node.memory_type = $scope" in query + assert "node.status = $status" in query + assert "node.user_name = $user_name" in query + assert "vector.similarity.cosine" in query + + def test_prefilter_includes_embedding_not_null(self, 
shared_neo4j_db): + """Pre-filter query should exclude nodes without embeddings.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + scope="LongTermMemory", + ) + + query = session_mock.run.call_args[0][0] + assert "node.embedding IS NOT NULL" in query + + def test_prefilter_has_order_by_and_limit(self, shared_neo4j_db): + """Pre-filter results should be ordered by score and limited.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + scope="LongTermMemory", + ) + + query = session_mock.run.call_args[0][0] + assert "ORDER BY score DESC" in query + assert "LIMIT $top_k" in query + params = session_mock.run.call_args[0][1] + assert params["top_k"] == 5 + + def test_ann_path_without_filters(self, multi_neo4j_db): + """Without any filter, query should use queryNodes ANN index.""" + session_mock = multi_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + multi_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + ) + + query = session_mock.run.call_args[0][0] + assert "queryNodes" in query + assert "$top_k" in query + assert "MATCH (node:Memory)" not in query + params = session_mock.run.call_args[0][1] + assert params["top_k"] == 5 + + def test_ann_path_no_redundant_params(self, multi_neo4j_db): + """ANN path should only have embedding and top_k, nothing extra.""" + session_mock = multi_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + multi_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + ) + + params = session_mock.run.call_args[0][1] + assert set(params.keys()) == {"embedding", "top_k"} + + +# 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Unit tests: sources KeyError regression +# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestSourcesKeyErrorRegression: + """Verify that missing/None 'sources' key doesn't cause KeyError.""" + + def test_add_node_without_sources_key(self, shared_neo4j_db): + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + + shared_neo4j_db.add_node( + id="test-node-1", + memory="test content", + metadata={ + "memory_type": "WorkingMemory", + "embedding": [0.1, 0.2, 0.3], + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + ) + + calls = session_mock.run.call_args_list + assert any("MERGE (n:Memory" in str(call) for call in calls) + + def test_add_node_with_empty_sources(self, shared_neo4j_db): + _session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + + shared_neo4j_db.add_node( + id="test-node-2", + memory="test content", + metadata={ + "memory_type": "WorkingMemory", + "embedding": [0.1, 0.2, 0.3], + "sources": [], + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + ) + + def test_parse_node_without_sources_key(self, shared_neo4j_db): + result = shared_neo4j_db._parse_node( + { + "id": "node-1", + "memory": "hello", + "memory_type": "WorkingMemory", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + assert result["id"] == "node-1" + assert result["memory"] == "hello" + + +# 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Integration tests (require a running Neo4j 5.18+ with vector index) +# +# Activate by setting environment variables: +# NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD +# +# Run: +# pytest tests/graph_dbs/test_neo4j_vector_search.py -k Integration -v +# โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _neo4j_package_available(): + try: + import neo4j # noqa: F401 + + return True + except ImportError: + return False + + +_neo4j_configured = _neo4j_package_available() and all( + os.getenv(k) for k in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD") +) +_TEST_RUN_ID = uuid.uuid4().hex[:8] +_TARGET_USER = f"__test_target_{_TEST_RUN_ID}" +_OTHER_USER_PREFIX = f"__test_other_{_TEST_RUN_ID}" + + +def _make_unit_vector( + dim: int, dominant_axis: int, secondary_axis: int | None = None +) -> list[float]: + """ + Create a unit vector concentrated on one axis, optionally blended with a second. 
+ + Used to control cosine similarity in tests: + - Two vectors on the same axis โ†’ cos_sim โ‰ˆ 1.0 + - Orthogonal axes โ†’ cos_sim โ‰ˆ 0.0 + - Blended โ†’ cos_sim โ‰ˆ 0.707 + """ + vec = [0.0] * dim + vec[dominant_axis % dim] = 1.0 + if secondary_axis is not None: + vec[secondary_axis % dim] = 1.0 + norm = math.sqrt(sum(x * x for x in vec)) + return [x / norm for x in vec] + + +@pytest.fixture(scope="module") +def integration_config(): + if not _neo4j_configured: + pytest.skip("Neo4j not configured (need NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)") + return Neo4jGraphDBConfig( + uri=os.getenv("NEO4J_URI"), + user=os.getenv("NEO4J_USER"), + password=os.getenv("NEO4J_PASSWORD"), + db_name=os.getenv("NEO4J_DB_NAME", "neo4j"), + auto_create=False, + use_multi_db=False, + user_name=f"__test_default_{_TEST_RUN_ID}", + embedding_dimension=int(os.getenv("EMBEDDING_DIMENSION", "1536")), + ) + + +@pytest.fixture(scope="module") +def integration_db(integration_config): + from memos.graph_dbs.neo4j import Neo4jGraphDB + + return Neo4jGraphDB(integration_config) + + +@pytest.mark.skipif(not _neo4j_configured, reason="Neo4j not configured") +class TestNeo4jPreFilterIntegration: + """ + Integration test: pre-filtered vector search in a multi-user shared database. + + Uses vector.similarity.cosine() with MATCH + WHERE to pre-filter by user, + guaranteeing that target user's nodes are always considered regardless of + how many other users' nodes exist in the database. + """ + + @pytest.fixture(scope="class", autouse=True) + def seed_and_cleanup(self, integration_db, integration_config): + """ + Seed multi-user test data, then clean up. + + - 50 "other" user nodes: embeddings along axis 0 โ†’ cos_sim โ‰ˆ 1.0 with query + - 3 "target" user nodes: embeddings blended axis 0+1 โ†’ cos_sim โ‰ˆ 0.707 with query + + With pre-filtering, only the target user's 3 nodes are candidates for + similarity computation, so all 3 are always returned. 
+ """ + dim = integration_config.embedding_dimension + now = datetime.now(timezone.utc).isoformat() + + for i in range(50): + other_user = f"{_OTHER_USER_PREFIX}_{i % 10}" + integration_db.add_node( + id=f"__test_other_{_TEST_RUN_ID}_{i}", + memory=f"Other user memory {i}", + metadata={ + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": _make_unit_vector(dim, dominant_axis=0), + "created_at": now, + "updated_at": now, + }, + user_name=other_user, + ) + + for i in range(3): + integration_db.add_node( + id=f"__test_target_{_TEST_RUN_ID}_{i}", + memory=f"Target user memory {i}", + metadata={ + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": _make_unit_vector(dim, dominant_axis=0, secondary_axis=1), + "created_at": now, + "updated_at": now, + }, + user_name=_TARGET_USER, + ) + + yield + + integration_db.clear(user_name=_TARGET_USER) + for i in range(10): + integration_db.clear(user_name=f"{_OTHER_USER_PREFIX}_{i}") + + def test_search_returns_all_target_user_results(self, integration_db, integration_config): + """Pre-filtering guarantees all target user nodes are found.""" + dim = integration_config.embedding_dimension + query_vector = _make_unit_vector(dim, dominant_axis=0) + + results = integration_db.search_by_embedding( + vector=query_vector, + top_k=3, + scope="LongTermMemory", + status="activated", + user_name=_TARGET_USER, + ) + + assert len(results) == 3, ( + f"Pre-filter should return all 3 target user nodes, got {len(results)}. " + "This indicates pre-filtering is not working correctly." 
+ ) + + def test_all_returned_ids_belong_to_target_user(self, integration_db, integration_config): + dim = integration_config.embedding_dimension + query_vector = _make_unit_vector(dim, dominant_axis=0) + + results = integration_db.search_by_embedding( + vector=query_vector, + top_k=3, + scope="LongTermMemory", + status="activated", + user_name=_TARGET_USER, + ) + + for r in results: + assert r["id"].startswith(f"__test_target_{_TEST_RUN_ID}_"), ( + f"Result {r['id']} does not belong to the target user" + ) + + def test_scores_are_positive(self, integration_db, integration_config): + dim = integration_config.embedding_dimension + query_vector = _make_unit_vector(dim, dominant_axis=0) + + results = integration_db.search_by_embedding( + vector=query_vector, + top_k=3, + scope="LongTermMemory", + status="activated", + user_name=_TARGET_USER, + ) + + for r in results: + assert r["score"] > 0, f"Score should be positive, got {r['score']}" diff --git a/tests/llms/test_minimax.py b/tests/llms/test_minimax.py new file mode 100644 index 000000000..d984adcef --- /dev/null +++ b/tests/llms/test_minimax.py @@ -0,0 +1,114 @@ +import unittest + +from types import SimpleNamespace +from unittest.mock import MagicMock + +from memos.configs.llm import MinimaxLLMConfig +from memos.llms.minimax import MinimaxLLM + + +class TestMinimaxLLM(unittest.TestCase): + def test_minimax_llm_generate_with_and_without_think_prefix(self): + """Test MinimaxLLM generate method with and without tag removal.""" + + # Simulated full content including tag + full_content = "Hello from MiniMax!" + reasoning_content = "Thinking in progress..." 
+ + # Mock response object + mock_response = MagicMock() + mock_response.model_dump_json.return_value = '{"mock": "true"}' + mock_response.choices[0].message.content = full_content + mock_response.choices[0].message.reasoning_content = reasoning_content + + # Config with think prefix preserved + config_with_think = MinimaxLLMConfig.model_validate( + { + "model_name_or_path": "MiniMax-M2.7", + "temperature": 0.7, + "max_tokens": 512, + "top_p": 0.9, + "api_key": "sk-test", + "api_base": "https://api.minimax.io/v1", + "remove_think_prefix": False, + } + ) + llm_with_think = MinimaxLLM(config_with_think) + llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response) + + output_with_think = llm_with_think.generate([{"role": "user", "content": "Hello"}]) + self.assertEqual(output_with_think, f"{reasoning_content}{full_content}") + + # Config with think tag removed + config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True}) + llm_without_think = MinimaxLLM(config_without_think) + llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response) + + output_without_think = llm_without_think.generate([{"role": "user", "content": "Hello"}]) + self.assertEqual(output_without_think, full_content) + + def test_minimax_llm_generate_stream(self): + """Test MinimaxLLM generate_stream with content chunks.""" + + def make_chunk(delta_dict): + # Create a simulated stream chunk with delta fields + delta = SimpleNamespace(**delta_dict) + choice = SimpleNamespace(delta=delta) + return SimpleNamespace(choices=[choice]) + + # Simulate chunks: content only (MiniMax standard response) + mock_stream_chunks = [ + make_chunk({"content": "Hello"}), + make_chunk({"content": ", "}), + make_chunk({"content": "MiniMax!"}), + ] + + mock_chat_completions_create = MagicMock(return_value=iter(mock_stream_chunks)) + + config = MinimaxLLMConfig.model_validate( + { + "model_name_or_path": "MiniMax-M2.7", + "temperature": 0.7, 
+ "max_tokens": 512, + "top_p": 0.9, + "api_key": "sk-test", + "api_base": "https://api.minimax.io/v1", + "remove_think_prefix": False, + } + ) + llm = MinimaxLLM(config) + llm.client.chat.completions.create = mock_chat_completions_create + + messages = [{"role": "user", "content": "Say hello"}] + streamed = list(llm.generate_stream(messages)) + full_output = "".join(streamed) + + self.assertEqual(full_output, "Hello, MiniMax!") + + def test_minimax_llm_config_defaults(self): + """Test MinimaxLLMConfig default values.""" + config = MinimaxLLMConfig.model_validate( + { + "model_name_or_path": "MiniMax-M2.7", + "api_key": "sk-test", + } + ) + self.assertEqual(config.api_base, "https://api.minimax.io/v1") + self.assertEqual(config.temperature, 0.7) + self.assertEqual(config.max_tokens, 8192) + + def test_minimax_llm_config_custom_values(self): + """Test MinimaxLLMConfig with custom values.""" + config = MinimaxLLMConfig.model_validate( + { + "model_name_or_path": "MiniMax-M2.7-highspeed", + "api_key": "sk-test", + "api_base": "https://custom.api.minimax.io/v1", + "temperature": 0.5, + "max_tokens": 2048, + } + ) + self.assertEqual(config.model_name_or_path, "MiniMax-M2.7-highspeed") + self.assertEqual(config.api_base, "https://custom.api.minimax.io/v1") + self.assertEqual(config.temperature, 0.5) + self.assertEqual(config.max_tokens, 2048) diff --git a/tests/test_cli.py b/tests/test_cli.py index 9750af121..a1e423e4f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,13 +16,13 @@ class TestExportOpenAPI: """Test the export_openapi function.""" - @patch("memos.api.start_api.app") + @patch("memos.cli.get_openapi_app") @patch("builtins.open", new_callable=mock_open) @patch("os.makedirs") def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app): """Test successful OpenAPI export.""" mock_openapi_data = {"openapi": "3.0.0", "info": {"title": "Test API"}} - mock_app.openapi.return_value = mock_openapi_data + 
mock_app.return_value.openapi.return_value = mock_openapi_data result = export_openapi("/test/path/openapi.json") @@ -30,11 +30,11 @@ def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app): mock_makedirs.assert_called_once_with("/test/path", exist_ok=True) mock_file.assert_called_once_with("/test/path/openapi.json", "w") - @patch("memos.api.start_api.app") + @patch("memos.cli.get_openapi_app") @patch("builtins.open", side_effect=OSError("Permission denied")) def test_export_openapi_error(self, mock_file, mock_app): """Test OpenAPI export when file writing fails.""" - mock_app.openapi.return_value = {"test": "data"} + mock_app.return_value.openapi.return_value = {"test": "data"} with pytest.raises(IOError): export_openapi("/invalid/path/openapi.json")