diff --git a/.gitignore b/.gitignore
index ab7848f74..c972d59de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -234,6 +234,6 @@ apps/openwork-memos-integration/apps/desktop/public/assets/usecases/
# Outputs and Evaluation Results
outputs
-evaluation/data/temporal_locomo
+evaluation/data/
test_add_pipeline.py
test_file_pipeline.py
diff --git a/Makefile b/Makefile
index 788504a73..eb22e241d 100644
--- a/Makefile
+++ b/Makefile
@@ -36,7 +36,7 @@ pre_commit:
poetry run pre-commit run -a
serve:
- poetry run uvicorn memos.api.start_api:app
+ poetry run uvicorn memos.api.server_api:app
openapi:
poetry run memos export_openapi --output docs/openapi.json
diff --git a/README.md b/README.md
index a7b05d683..82ec02772 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@
- [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) โ intelligent memory retrieval instead of loading full chat history
- [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) โ multi-instance agents share memory via same user_id, automatic context handoff
-Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/)
+Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/)
Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)
### ๐ง Local Plugin โ 100% On-Device Memory
@@ -84,7 +84,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **Hybrid search + task & skill evolution** โ FTS5 + vector search, auto task summarization, reusable skills that self-upgrade
- **Multi-agent collaboration + Memory Viewer** โ memory isolation, skill sharing, full web dashboard with 7 management pages
- ๐ [Homepage](https://memos-claw.openmem.net) ยท
+ ๐ [Homepage](https://memos-claw.openmem.net) ยท
๐ [Documentation](https://memos-claw.openmem.net/docs/index.html) ยท ๐ฆ [NPM](https://www.npmjs.com/package/@memtensor/memos-local-openclaw-plugin)
## ๐ MemOS: Memory Operating System for AI Agents
@@ -104,10 +104,10 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
### News
-- **2026-03-08** ยท ๐ฆ **MemOS OpenClaw Plugin โ Cloud & Local**
+- **2026-03-08** ยท ๐ฆ **MemOS OpenClaw Plugin โ Cloud & Local**
Official OpenClaw memory plugins launched. **Cloud Plugin**: hosted memory service with 72% lower token usage and multi-agent memory sharing ([MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)). **Local Plugin** (`v1.0.0`): 100% on-device memory with persistent SQLite, hybrid search (FTS5 + vector), task summarization & skill evolution, multi-agent collaboration, and a full Memory Viewer dashboard.
-- **2025-12-24** ยท ๐ **MemOS v2.0: Stardust (ๆๅฐ) Release**
+- **2025-12-24** ยท ๐ **MemOS v2.0: Stardust (ๆๅฐ) Release**
Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
โจ New Features
@@ -155,7 +155,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **2025-08-07** ยท ๐ **MemOS v1.0.0 (MemCube) Release**
- First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch.
+ First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, improved search capabilities, and the official Playground launch.
โจ New Features
@@ -176,7 +176,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
**Plaintext Memory**
- Integrated internet search with Bocha.
- - Added support for Nebula database.
+ - Expanded graph database support.
- Added contextual understanding for the tree-structured plaintext memory search interface.
@@ -188,7 +188,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- Fixed the concat_cache method.
**Plaintext Memory**
- - Fixed Nebula search-related issues.
+ - Fixed graph search-related issues.
@@ -224,6 +224,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
2. Configure `docker/.env.example` and copy to `MemOS/.env`
- The `OPENAI_API_KEY`,`MOS_EMBEDDER_API_KEY`,`MEMRADER_API_KEY` and others can be applied for through [`BaiLian`](https://bailian.console.aliyun.com/?spm=a2c4g.11186623.0.0.2f2165b08fRk4l&tab=api#/api).
- Fill in the corresponding configuration in the `MemOS/.env` file.
+ - Supported LLM providers: **OpenAI**, **Azure OpenAI**, **Qwen (DashScope)**, **DeepSeek**, **MiniMax**, **Ollama**, **HuggingFace**, **vLLM**. Set `MOS_CHAT_MODEL_PROVIDER` to select the backend (e.g., `openai`, `qwen`, `deepseek`, `minimax`).
3. Start the service.
- Launch via Docker
diff --git a/apps/memos-local-openclaw/.env.example b/apps/memos-local-openclaw/.env.example
index bfb409298..ce292ba80 100644
--- a/apps/memos-local-openclaw/.env.example
+++ b/apps/memos-local-openclaw/.env.example
@@ -18,6 +18,10 @@ SUMMARIZER_TEMPERATURE=0
# Port for the web-based Memory Viewer (default: 18799)
# VIEWER_PORT=18799
+# โโโ Tavily Search (optional) โโโ
+# API key for Tavily web search (get from https://app.tavily.com)
+# TAVILY_API_KEY=tvly-your-tavily-api-key
+
# โโโ Telemetry (opt-out) โโโ
# Anonymous usage analytics to help improve the plugin.
# No memory content, queries, or personal data is ever sent โ only tool names, latencies, and version info.
diff --git a/docker/.env.example b/docker/.env.example
index 3674cd69b..c9d8e714e 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -25,9 +25,12 @@ MOS_MAX_TOKENS=2048
# Top-P for LLM in the Product API
MOS_TOP_P=0.9
# LLM for the Product API backend
-MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm
+MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm | minimax
OPENAI_API_KEY=sk-xxx # [required] when provider=openai
OPENAI_API_BASE=https://api.openai.com/v1 # [required] base for the key
+# MiniMax LLM (when provider=minimax)
+# MINIMAX_API_KEY=your-minimax-api-key # [required] when provider=minimax
+# MINIMAX_API_BASE=https://api.minimax.io/v1 # base for MiniMax API
## MemReader / retrieval LLM
MEMRADER_MODEL=gpt-4o-mini
@@ -80,8 +83,12 @@ EMBEDDING_MODEL=nomic-embed-text:latest
## Internet search & preference memory
# Enable web search
ENABLE_INTERNET=false
+# Internet search backend (bocha | tavily)
+INTERNET_SEARCH_BACKEND=bocha
# API key for BOCHA Search
-BOCHA_API_KEY= # required if ENABLE_INTERNET=true
+BOCHA_API_KEY= # required if ENABLE_INTERNET=true and backend=bocha
+# API key for Tavily Search
+TAVILY_API_KEY= # required if ENABLE_INTERNET=true and backend=tavily
# default search mode
SEARCH_MODE=fast # fast | fine | mixture
# Slow retrieval strategy configuration, rewrite is the rewrite strategy
@@ -127,7 +134,7 @@ MEMSCHEDULER_USE_REDIS_QUEUE=false
## Graph / vector stores
# Neo4j database selection mode
-NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb
+NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | polardb | postgres
# Neo4j database url
NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j*
# Neo4j database user
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 0a8e2c634..6805ec781 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -26,7 +26,7 @@ services:
- memos_network
neo4j:
- image: neo4j:5.26.4
+ image: neo4j:5.26.6
container_name: neo4j-docker
ports:
- "7474:7474" # HTTP
diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt
index a14257a76..b148d43d6 100644
--- a/docker/requirements-full.txt
+++ b/docker/requirements-full.txt
@@ -185,3 +185,4 @@ py-key-value-shared==0.2.8
PyJWT==2.10.1
pytest==9.0.2
alibabacloud-oss-v2==1.2.2
+tavily-python==0.7.23
diff --git a/docker/requirements.txt b/docker/requirements.txt
index 340f4e140..988e64b83 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -124,3 +124,4 @@ uvloop==0.22.1; sys_platform != 'win32'
watchfiles==1.1.1
websockets==15.0.1
alibabacloud-oss-v2==1.2.2
+tavily-python==0.7.23
diff --git a/examples/basic_modules/llm.py b/examples/basic_modules/llm.py
index fb157c991..3fd7352c7 100644
--- a/examples/basic_modules/llm.py
+++ b/examples/basic_modules/llm.py
@@ -164,7 +164,37 @@
print("Scenario 6:", resp)
-# Scenario 7: Using LLMFactory with Deepseek-chat + reasoning + CoT + streaming
+# Scenario 7: Using LLMFactory with MiniMax (OpenAI-compatible API)
+# Prerequisites:
+# 1. Get your API key from the MiniMax platform.
+# 2. Available models: MiniMax-M2.7 (flagship), MiniMax-M2.7-highspeed (low-latency),
+# MiniMax-M2.5, MiniMax-M2.5-highspeed.
+
+cfg_mm = LLMConfigFactory.model_validate(
+ {
+ "backend": "minimax",
+ "config": {
+ "model_name_or_path": "MiniMax-M2.7",
+ "api_key": "your-minimax-api-key",
+ "api_base": "https://api.minimax.io/v1",
+ "temperature": 0.7,
+ "max_tokens": 1024,
+ },
+ }
+)
+llm = LLMFactory.from_config(cfg_mm)
+messages = [{"role": "user", "content": "Hello, who are you"}]
+resp = llm.generate(messages)
+print("Scenario 7:", resp)
+print("==" * 20)
+
+print("Scenario 7 (streaming):\n")
+for chunk in llm.generate_stream(messages):
+ print(chunk, end="")
+print("\n" + "==" * 20)
+
+
+# Scenario 8: Using LLMFactory with DeepSeek-chat + reasoning + CoT + streaming
cfg2 = LLMConfigFactory.model_validate(
{
@@ -186,7 +216,7 @@
"content": "Explain how to solve this problem step-by-step. Be explicit in your thinking process. Question: If a train travels from city A to city B at 60 mph and returns at 40 mph, what is its average speed for the entire trip? Let's think step by step.",
},
]
-print("Scenario 7:\n")
+print("Scenario 8:\n")
for chunk in llm.generate_stream(messages):
print(chunk, end="")
print("==" * 20)
diff --git a/examples/basic_modules/neo4j_example.py b/examples/basic_modules/neo4j_example.py
index e1c0df317..2a5e88749 100644
--- a/examples/basic_modules/neo4j_example.py
+++ b/examples/basic_modules/neo4j_example.py
@@ -2,6 +2,8 @@
from datetime import datetime
+from dotenv import load_dotenv
+
from memos.configs.embedder import EmbedderConfigFactory
from memos.configs.graph_db import GraphDBConfigFactory
from memos.embedders.factory import EmbedderFactory
@@ -9,14 +11,27 @@
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+load_dotenv()
+
+NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
+NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
+NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")
+NEO4J_DB_NAME = os.getenv("NEO4J_DB_NAME", "neo4j")
+EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIMENSION", "3072"))
+
+QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
+QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
+
embedder_config = EmbedderConfigFactory.model_validate(
{
- "backend": "universal_api",
+ "backend": os.getenv("MOS_EMBEDDER_BACKEND", "universal_api"),
"config": {
- "provider": "openai",
- "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
- "model_name_or_path": "text-embedding-3-large",
- "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+ "provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"),
+ "api_key": os.getenv("MOS_EMBEDDER_API_KEY", os.getenv("OPENAI_API_KEY", "")),
+ "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"),
+ "base_url": os.getenv(
+ "MOS_EMBEDDER_API_BASE", os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
+ ),
},
}
)
@@ -31,12 +46,12 @@ def get_neo4j_graph(db_name: str = "paper"):
config = GraphDBConfigFactory(
backend="neo4j",
config={
- "uri": "bolt://xxxx:7687",
- "user": "neo4j",
- "password": "xxxx",
+ "uri": NEO4J_URI,
+ "user": NEO4J_USER,
+ "password": NEO4J_PASSWORD,
"db_name": db_name,
"auto_create": True,
- "embedding_dimension": 3072,
+ "embedding_dimension": EMBEDDING_DIMENSION,
"use_multi_db": True,
},
)
@@ -49,12 +64,12 @@ def example_multi_db(db_name: str = "paper"):
config = GraphDBConfigFactory(
backend="neo4j",
config={
- "uri": "bolt://localhost:7687",
- "user": "neo4j",
- "password": "12345678",
+ "uri": NEO4J_URI,
+ "user": NEO4J_USER,
+ "password": NEO4J_PASSWORD,
"db_name": db_name,
"auto_create": True,
- "embedding_dimension": 3072,
+ "embedding_dimension": EMBEDDING_DIMENSION,
"use_multi_db": True,
},
)
@@ -288,14 +303,14 @@ def example_shared_db(db_name: str = "shared-traval-group"):
config = GraphDBConfigFactory(
backend="neo4j",
config={
- "uri": "bolt://localhost:7687",
- "user": "neo4j",
- "password": "12345678",
+ "uri": NEO4J_URI,
+ "user": NEO4J_USER,
+ "password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_name,
"use_multi_db": False,
"auto_create": True,
- "embedding_dimension": 3072,
+ "embedding_dimension": EMBEDDING_DIMENSION,
},
)
# Step 2: Instantiate graph store
@@ -353,12 +368,12 @@ def example_shared_db(db_name: str = "shared-traval-group"):
config_alice = GraphDBConfigFactory(
backend="neo4j",
config={
- "uri": "bolt://localhost:7687",
- "user": "neo4j",
- "password": "12345678",
+ "uri": NEO4J_URI,
+ "user": NEO4J_USER,
+ "password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_list[0],
- "embedding_dimension": 3072,
+ "embedding_dimension": EMBEDDING_DIMENSION,
},
)
graph_alice = GraphStoreFactory.from_config(config_alice)
@@ -382,24 +397,22 @@ def run_user_session(
config = GraphDBConfigFactory(
backend="neo4j-community",
config={
- "uri": "bolt://localhost:7687",
- "user": "neo4j",
- "password": "12345678",
+ "uri": NEO4J_URI,
+ "user": NEO4J_USER,
+ "password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_name,
"use_multi_db": False,
- "auto_create": False, # Neo4j Community does not allow auto DB creation
- "embedding_dimension": 3072,
+ "auto_create": False,
+ "embedding_dimension": EMBEDDING_DIMENSION,
"vec_config": {
- # Pass nested config to initialize external vector DB
- # If you use qdrant, please use Server instead of local mode.
"backend": "qdrant",
"config": {
"collection_name": "neo4j_vec_db",
- "vector_dimension": 3072,
+ "vector_dimension": EMBEDDING_DIMENSION,
"distance_metric": "cosine",
- "host": "localhost",
- "port": 6333,
+ "host": QDRANT_HOST,
+ "port": QDRANT_PORT,
},
},
},
@@ -408,14 +421,14 @@ def run_user_session(
config = GraphDBConfigFactory(
backend="neo4j",
config={
- "uri": "bolt://localhost:7687",
- "user": "neo4j",
- "password": "12345678",
+ "uri": NEO4J_URI,
+ "user": NEO4J_USER,
+ "password": NEO4J_PASSWORD,
"db_name": db_name,
"user_name": user_name,
"use_multi_db": False,
"auto_create": True,
- "embedding_dimension": 3072,
+ "embedding_dimension": EMBEDDING_DIMENSION,
},
)
graph = GraphStoreFactory.from_config(config)
diff --git a/examples/basic_modules/textual_memory_internet_search_example.py b/examples/basic_modules/textual_memory_internet_search_example.py
index 9007d7e67..48c5b5f72 100644
--- a/examples/basic_modules/textual_memory_internet_search_example.py
+++ b/examples/basic_modules/textual_memory_internet_search_example.py
@@ -22,6 +22,8 @@
**Prerequisites:**
- Valid BochaAI API Key (set in environment variable: BOCHA_API_KEY)
+- (Optional) Valid Tavily API Key (set in environment variable: TAVILY_API_KEY)
+ - Get from https://app.tavily.com (1,000 free credits/month)
- (Optional) Valid Google API Key and Search Engine ID for Google Custom Search
- GOOGLE_API_KEY: Get from https://console.cloud.google.com/
- GOOGLE_SEARCH_ENGINE_ID: Get from https://programmablesearchengine.google.com/
@@ -288,19 +290,102 @@
print("\n Get your credentials from:")
print(" https://developers.google.com/custom-search/v1/overview")
+# ============================================================================
+# Step 7: Test Tavily Search API (Optional)
+# ============================================================================
+print("\n" + "=" * 80)
+print("TAVILY SEARCH API TEST")
+print("=" * 80)
+
+tavily_api_key = os.environ.get("TAVILY_API_KEY", "")
+
+if tavily_api_key:
+ print("\n[Step 7.1] Configuring Tavily Search retriever...")
+
+ tavily_retriever_config = InternetRetrieverConfigFactory.model_validate(
+ {
+ "backend": "tavily",
+ "config": {
+ "api_key": tavily_api_key,
+ "max_results": 5,
+ "search_depth": "basic",
+ "include_answer": False,
+ },
+ }
+ )
+
+ print(" Retriever configured: tavily")
+ print(f" Max results: {tavily_retriever_config.config.max_results}")
+
+ print("\n[Step 7.2] Creating Tavily retriever instance...")
+ tavily_retriever = InternetRetrieverFactory.from_config(tavily_retriever_config, embedder)
+ print(" Tavily retriever initialized")
+
+ print("\n[Step 7.3] Performing Tavily web search...")
+ tavily_query = "latest developments in AI 2024"
+ print(f" Query: '{tavily_query}'")
+ print(" Searching via Tavily Search API...\n")
+
+ tavily_results = tavily_retriever.retrieve_from_internet(tavily_query)
+
+ print(f" Tavily search completed! Retrieved {len(tavily_results)} memory items\n")
+
+ print("=" * 80)
+ print("TAVILY SEARCH RESULTS")
+ print("=" * 80)
+
+ if not tavily_results:
+ print("\n No results found from Tavily.")
+ print(" This might indicate:")
+ print(" - Invalid Tavily API key")
+ print(" - Network connectivity issues")
+ else:
+ for idx, item in enumerate(tavily_results, 1):
+ print(f"\n[Tavily Result #{idx}]")
+ print("-" * 80)
+
+ content = item.memory
+ if len(content) > 300:
+ print(f"Content: {content[:300]}...")
+ print(f" (... {len(content) - 300} more characters)")
+ else:
+ print(f"Content: {content}")
+
+ if hasattr(item, "metadata") and item.metadata:
+ metadata = item.metadata
+ if hasattr(metadata, "sources") and metadata.sources:
+ print(f"Source: {metadata.sources[0] if metadata.sources else 'N/A'}")
+
+ print()
+
+ print("=" * 80)
+ print("Tavily Search Test completed!")
+ print("=" * 80)
+else:
+ print("\n Skipping Tavily Search API test")
+ print(" To enable this test, set the following environment variable:")
+ print(" - TAVILY_API_KEY: Your Tavily API key (get from https://app.tavily.com)")
+
print("\n" + "=" * 80)
print("ALL TESTS COMPLETED")
print("=" * 80)
-print("\n๐ก Summary:")
-print(" โ Tested BochaAI web search retriever")
+print("\n Summary:")
+print(" - Tested BochaAI web search retriever")
if google_api_key and google_search_engine_id:
- print(" โ Tested Google Custom Search API")
+ print(" - Tested Google Custom Search API")
else:
- print(" โญ๏ธ Skipped Google Custom Search API (credentials not set)")
-print("\n๐ก Quick Start:")
+ print(" - Skipped Google Custom Search API (credentials not set)")
+if tavily_api_key:
+ print(" - Tested Tavily Search API")
+else:
+ print(" - Skipped Tavily Search API (credentials not set)")
+print("\n Quick Start:")
print(" # Set BochaAI API key")
print(" export BOCHA_API_KEY='sk-your-bocha-api-key'")
print(" ")
+print(" # Set Tavily API key (optional)")
+print(" export TAVILY_API_KEY='tvly-your-tavily-api-key'")
+print(" ")
print(" # Set Google Custom Search credentials (optional)")
print(" export GOOGLE_API_KEY='your-google-api-key'")
print(" export GOOGLE_SEARCH_ENGINE_ID='your-search-engine-id'")
diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py
index 6dbe202c2..d14b6d687 100644
--- a/examples/mem_agent/deepsearch_example.py
+++ b/examples/mem_agent/deepsearch_example.py
@@ -47,7 +47,7 @@ def build_minimal_components():
# Build component configurations using APIConfig methods (like config_builders.py)
- # Graph DB configuration - using APIConfig.get_nebular_config()
+ # Graph DB configuration - using APIConfig graph DB helpers
graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower()
graph_db_backend_map = {
"polardb": APIConfig.get_polardb_config(),
diff --git a/examples/mem_mcp/simple_fastmcp_serve.py b/examples/mem_mcp/simple_fastmcp_serve.py
index 55ad4d84d..5071314e0 100644
--- a/examples/mem_mcp/simple_fastmcp_serve.py
+++ b/examples/mem_mcp/simple_fastmcp_serve.py
@@ -23,7 +23,7 @@ def add_memory(memory_content: str, user_id: str, cube_id: str | None = None):
"""Add memory using the Server API."""
payload = {
"user_id": user_id,
- "messages": memory_content,
+ "messages": [{"role": "user", "content": memory_content}],
"writable_cube_ids": [cube_id] if cube_id else None,
}
try:
diff --git a/poetry.lock b/poetry.lock
index 72049f025..dfea31354 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -1897,7 +1897,7 @@ files = [
[package.dependencies]
attrs = ">=22.2.0"
-jsonschema-specifications = ">=2023.3.6"
+jsonschema-specifications = ">=2023.03.6"
referencing = ">=0.28.4"
rpds-py = ">=0.7.1"
@@ -2672,7 +2672,6 @@ files = [
{file = "nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1"},
{file = "nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868"},
]
-markers = {main = "extra == \"all\""}
[package.dependencies]
click = "*"
@@ -5476,6 +5475,24 @@ mpmath = ">=1.1.0,<1.4"
[package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
+[[package]]
+name = "tavily-python"
+version = "0.7.23"
+description = "Python wrapper for the Tavily API"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"tavily\" or extra == \"all\""
+files = [
+ {file = "tavily_python-0.7.23-py3-none-any.whl", hash = "sha256:52ef85c44b926bce3f257570cd32bc1bd4db54666acf3105617f27411a59e188"},
+ {file = "tavily_python-0.7.23.tar.gz", hash = "sha256:3b92232e0e29ab68898b765f281bb4f2c650b02210b64affbc48e15292e96161"},
+]
+
+[package.dependencies]
+httpx = "*"
+requests = "*"
+tiktoken = ">=0.5.1"
+
[[package]]
name = "tenacity"
version = "9.1.2"
@@ -5504,6 +5521,81 @@ files = [
{file = "threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e"},
]
+[[package]]
+name = "tiktoken"
+version = "0.12.0"
+description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"tavily\" or extra == \"all\""
+files = [
+ {file = "tiktoken-0.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3de02f5a491cfd179aec916eddb70331814bd6bf764075d39e21d5862e533970"},
+ {file = "tiktoken-0.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b6cfb6d9b7b54d20af21a912bfe63a2727d9cfa8fbda642fd8322c70340aad16"},
+ {file = "tiktoken-0.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:cde24cdb1b8a08368f709124f15b36ab5524aac5fa830cc3fdce9c03d4fb8030"},
+ {file = "tiktoken-0.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6de0da39f605992649b9cfa6f84071e3f9ef2cec458d08c5feb1b6f0ff62e134"},
+ {file = "tiktoken-0.12.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6faa0534e0eefbcafaccb75927a4a380463a2eaa7e26000f0173b920e98b720a"},
+ {file = "tiktoken-0.12.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:82991e04fc860afb933efb63957affc7ad54f83e2216fe7d319007dab1ba5892"},
+ {file = "tiktoken-0.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:6fb2995b487c2e31acf0a9e17647e3b242235a20832642bb7a9d1a181c0c1bb1"},
+ {file = "tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb"},
+ {file = "tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa"},
+ {file = "tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc"},
+ {file = "tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded"},
+ {file = "tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd"},
+ {file = "tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967"},
+ {file = "tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def"},
+ {file = "tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8"},
+ {file = "tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b"},
+ {file = "tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37"},
+ {file = "tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad"},
+ {file = "tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5"},
+ {file = "tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3"},
+ {file = "tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd"},
+ {file = "tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3"},
+ {file = "tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160"},
+ {file = "tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa"},
+ {file = "tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be"},
+ {file = "tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a"},
+ {file = "tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3"},
+ {file = "tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f"},
+ {file = "tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646"},
+ {file = "tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88"},
+ {file = "tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff"},
+ {file = "tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830"},
+ {file = "tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b"},
+ {file = "tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b"},
+ {file = "tiktoken-0.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:399c3dd672a6406719d84442299a490420b458c44d3ae65516302a99675888f3"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71"},
+ {file = "tiktoken-0.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:d51d75a5bffbf26f86554d28e78bfb921eae998edc2675650fd04c7e1f0cdc1e"},
+ {file = "tiktoken-0.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:09eb4eae62ae7e4c62364d9ec3a57c62eea707ac9a2b2c5d6bd05de6724ea179"},
+ {file = "tiktoken-0.12.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:df37684ace87d10895acb44b7f447d4700349b12197a526da0d4a4149fde074c"},
+ {file = "tiktoken-0.12.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:4c9614597ac94bb294544345ad8cf30dac2129c05e2db8dc53e082f355857af7"},
+ {file = "tiktoken-0.12.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:20cf97135c9a50de0b157879c3c4accbb29116bcf001283d26e073ff3b345946"},
+ {file = "tiktoken-0.12.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:15d875454bbaa3728be39880ddd11a5a2a9e548c29418b41e8fd8a767172b5ec"},
+ {file = "tiktoken-0.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cff3688ba3c639ebe816f8d58ffbbb0aa7433e23e08ab1cade5d175fc973fb3"},
+ {file = "tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931"},
+]
+
+[package.dependencies]
+regex = ">=2022.1.18"
+requests = ">=2.26.0"
+
+[package.extras]
+blobfile = ["blobfile (>=2)"]
+
[[package]]
name = "tokenizers"
version = "0.21.2"
@@ -6544,15 +6636,16 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
-all = ["alibabacloud-oss-v2", "cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "nltk", "pika", "pymilvus", "pymysql", "qdrant-client", "rake-nltk", "rank-bm25", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
+all = ["alibabacloud-oss-v2", "cachetools", "chonkie", "datasketch", "jieba", "langchain-text-splitters", "markitdown", "neo4j", "nltk", "pika", "pymilvus", "pymysql", "qdrant-client", "rake-nltk", "rank-bm25", "redis", "schedule", "sentence-transformers", "tavily-python", "torch", "volcengine-python-sdk"]
mem-reader = ["chonkie", "langchain-text-splitters", "markitdown"]
mem-scheduler = ["pika", "redis"]
mem-user = ["pymysql"]
pref-mem = ["datasketch", "pymilvus"]
skill-mem = ["alibabacloud-oss-v2"]
+tavily = ["tavily-python"]
tree-mem = ["neo4j", "schedule"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
-content-hash = "e0427aa672e57215033fe964847474521abf61b0a63f443744a6ec0b8c5ff2e2"
+content-hash = "86c718082278e9df7b994b87eb4b75c64ee4b5353954bf16dbec371d8264ce0a"
diff --git a/pyproject.toml b/pyproject.toml
index de8e66ad1..a359ee498 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -102,6 +102,11 @@ skill-mem = [
"alibabacloud-oss-v2 (>=1.2.2,<1.2.3)",
]
+# Tavily Search
+tavily = [
+ "tavily-python (>=0.5.0,<1.0.0)",
+]
+
# All optional dependencies
# Allow users to install with `pip install MemoryOS[all]`
all = [
@@ -129,6 +134,7 @@ all = [
"nltk (>=3.9.1,<4.0.0)",
"rake-nltk (>=1.0.6,<1.1.0)",
"alibabacloud-oss-v2 (>=1.2.2,<1.2.3)",
+ "tavily-python (>=0.5.0,<1.0.0)",
# Uncategorized dependencies
]
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 87f1efd8e..c68deae5a 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -284,6 +284,20 @@ def qwen_config() -> dict[str, Any]:
"remove_think_prefix": True,
}
+ @staticmethod
+ def minimax_config() -> dict[str, Any]:
+ """Get MiniMax configuration."""
+ return {
+ "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "MiniMax-M2.7"),
+ "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")),
+ "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")),
+ "top_p": float(os.getenv("MOS_TOP_P", "0.9")),
+ "top_k": int(os.getenv("MOS_TOP_K", "50")),
+ "remove_think_prefix": True,
+ "api_key": os.getenv("MINIMAX_API_KEY", "your-api-key-here"),
+ "api_base": os.getenv("MINIMAX_API_BASE", "https://api.minimax.io/v1"),
+ }
+
@staticmethod
def vllm_config() -> dict[str, Any]:
"""Get Qwen configuration."""
@@ -626,7 +640,26 @@ def get_oss_config() -> dict[str, Any] | None:
return config
def get_internet_config() -> dict[str, Any]:
- """Get embedder configuration."""
+ """Get internet retriever configuration.
+
+ Supports backends: bocha (default), tavily, google, bing, xinyu.
+ Set INTERNET_SEARCH_BACKEND env var to choose the backend.
+ For Tavily, set TAVILY_API_KEY env var.
+ For Bocha, set BOCHA_API_KEY env var.
+ """
+ backend = os.getenv("INTERNET_SEARCH_BACKEND", "bocha").lower()
+
+ if backend == "tavily":
+ return {
+ "backend": "tavily",
+ "config": {
+ "api_key": os.getenv("TAVILY_API_KEY", ""),
+ "max_results": 10,
+ "search_depth": os.getenv("TAVILY_SEARCH_DEPTH", "basic"),
+ "include_answer": os.getenv("TAVILY_INCLUDE_ANSWER", "false").lower() == "true",
+ },
+ }
+
reader_config = APIConfig.get_reader_config()
return {
"backend": "bocha",
@@ -741,21 +774,6 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
}
- @staticmethod
- def get_nebular_config(user_id: str | None = None) -> dict[str, Any]:
- """Get Nebular configuration."""
- return {
- "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
- "user": os.getenv("NEBULAR_USER", "root"),
- "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
- "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"),
- "user_name": f"memos{user_id.replace('-', '')}",
- "use_multi_db": False,
- "auto_create": True,
- "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
- }
-
- @staticmethod
def get_milvus_config():
return {
"collection_name": [
@@ -916,12 +934,14 @@ def get_product_default_config() -> dict[str, Any]:
openai_config = APIConfig.get_openai_config()
qwen_config = APIConfig.qwen_config()
vllm_config = APIConfig.vllm_config()
+ minimax_config = APIConfig.minimax_config()
reader_config = APIConfig.get_reader_config()
backend_model = {
"openai": openai_config,
"huggingface": qwen_config,
"vllm": vllm_config,
+ "minimax": minimax_config,
}
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
mysql_config = APIConfig.get_mysql_config()
@@ -1039,6 +1059,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
openai_config = APIConfig.get_openai_config()
qwen_config = APIConfig.qwen_config()
vllm_config = APIConfig.vllm_config()
+ minimax_config = APIConfig.minimax_config()
mysql_config = APIConfig.get_mysql_config()
reader_config = APIConfig.get_reader_config()
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
@@ -1046,6 +1067,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
"openai": openai_config,
"huggingface": qwen_config,
"vllm": vllm_config,
+ "minimax": minimax_config,
}
# Create MOSConfig
config_dict = {
@@ -1103,7 +1125,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
neo4j_config = APIConfig.get_neo4j_config(user_id)
- nebular_config = APIConfig.get_nebular_config(user_id)
polardb_config = APIConfig.get_polardb_config(user_id)
internet_config = (
APIConfig.get_internet_config()
@@ -1114,7 +1135,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
- "nebular": nebular_config,
"polardb": polardb_config,
"postgres": postgres_config,
}
@@ -1144,9 +1164,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
"memory_size": {
- "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)),
- "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)),
- "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)),
+ "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)),
+ "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)),
+ "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
@@ -1169,7 +1189,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
}
)
else:
- raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
+ raise ValueError(f"Invalid graph DB backend: {graph_db_backend}")
default_mem_cube = GeneralMemCube(default_cube_config)
return default_config, default_mem_cube
@@ -1188,13 +1208,11 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
openai_config = APIConfig.get_openai_config()
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
- nebular_config = APIConfig.get_nebular_config(user_id="default")
polardb_config = APIConfig.get_polardb_config(user_id="default")
postgres_config = APIConfig.get_postgres_config(user_id="default")
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
- "nebular": nebular_config,
"polardb": polardb_config,
"postgres": postgres_config,
}
@@ -1227,9 +1245,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
== "true",
"internet_retriever": internet_config,
"memory_size": {
- "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)),
- "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)),
- "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)),
+ "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)),
+ "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)),
+ "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
@@ -1253,4 +1271,4 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
}
)
else:
- raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
+ raise ValueError(f"Invalid graph DB backend: {graph_db_backend}")
diff --git a/src/memos/api/exceptions.py b/src/memos/api/exceptions.py
index 10a14b4d1..15644113c 100644
--- a/src/memos/api/exceptions.py
+++ b/src/memos/api/exceptions.py
@@ -14,13 +14,26 @@ class APIExceptionHandler:
@staticmethod
async def validation_error_handler(request: Request, exc: RequestValidationError):
"""Handle request validation errors."""
- logger.error(f"Validation error: {exc.errors()}")
+ errors = exc.errors()
+ path = request.url.path
+ method = request.method
+
+ readable_errors = []
+ for err in errors:
+ loc = " -> ".join(str(loc_i) for loc_i in err.get("loc", []))
+ readable_errors.append(
+ f"[{loc}] {err.get('msg', 'unknown error')} (type: {err.get('type', 'unknown')})"
+ )
+
+ logger.error(
+ f"Validation error on {method} {path}: {readable_errors}, raw errors: {errors}"
+ )
return JSONResponse(
status_code=422,
content={
"code": 422,
- "message": "Parameter validation error",
- "detail": exc.errors(),
+ "message": f"Parameter validation error on {method} {path}: {'; '.join(readable_errors)}",
+ "detail": errors,
"data": None,
},
)
diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py
index e071eacb3..eb982fd06 100644
--- a/src/memos/api/handlers/base_handler.py
+++ b/src/memos/api/handlers/base_handler.py
@@ -36,7 +36,6 @@ def __init__(
vector_db: Any | None = None,
internet_retriever: Any | None = None,
memory_manager: Any | None = None,
- mos_server: Any | None = None,
feedback_server: Any | None = None,
**kwargs,
):
@@ -54,7 +53,6 @@ def __init__(
vector_db: Vector database instance
internet_retriever: Internet retriever instance
memory_manager: Memory manager instance
- mos_server: MOS server instance
**kwargs: Additional dependencies
"""
self.llm = llm
@@ -68,7 +66,6 @@ def __init__(
self.vector_db = vector_db
self.internet_retriever = internet_retriever
self.memory_manager = memory_manager
- self.mos_server = mos_server
self.feedback_server = feedback_server
# Store any additional dependencies
@@ -158,11 +155,6 @@ def vector_db(self):
"""Get vector database instance."""
return self.deps.vector_db
- @property
- def mos_server(self):
- """Get MOS server instance."""
- return self.deps.mos_server
-
@property
def deepsearch_agent(self):
"""Get deepsearch agent instance."""
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index 7894ff7dc..a01fffef8 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -28,7 +28,6 @@
from memos.log import get_logger
from memos.mem_cube.navie import NaiveMemCube
from memos.mem_feedback.simple_feedback import SimpleMemFeedback
-from memos.mem_os.product_server import MOSServer
from memos.mem_reader.factory import MemReaderFactory
from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
@@ -211,15 +210,6 @@ def init_server() -> dict[str, Any]:
logger.debug("Text memory initialized")
- # Initialize MOS Server
- mos_server = MOSServer(
- mem_reader=mem_reader,
- llm=llm,
- online_bot=False,
- )
-
- logger.debug("MOS server initialized")
-
# Create MemCube with pre-initialized memory instances
naive_mem_cube = NaiveMemCube(
text_mem=text_mem,
@@ -304,7 +294,6 @@ def init_server() -> dict[str, Any]:
"internet_retriever": internet_retriever,
"memory_manager": memory_manager,
"default_cube_config": default_cube_config,
- "mos_server": mos_server,
"mem_scheduler": mem_scheduler,
"naive_mem_cube": naive_mem_cube,
"searcher": searcher,
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
index 5655bf1e5..d29429fc9 100644
--- a/src/memos/api/handlers/config_builders.py
+++ b/src/memos/api/handlers/config_builders.py
@@ -39,13 +39,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
graph_db_backend_map = {
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
"postgres": APIConfig.get_postgres_config(user_id=user_id),
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
- graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index 72c6a3da8..dc052fdbe 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -314,26 +314,99 @@ def handle_get_memories(
return GetMemoryResponse(message="Memories retrieved successfully", data=filtered_results)
+def _build_quick_delete_constraints(delete_mem_req: DeleteMemoryRequest) -> dict[str, Any]:
+ """Build fast-delete constraints from request-level fields."""
+ constraints: dict[str, Any] = {}
+ if delete_mem_req.user_id is not None:
+ constraints["user_id"] = delete_mem_req.user_id
+ if delete_mem_req.session_id is not None:
+ constraints["session_id"] = delete_mem_req.session_id
+ return constraints
+
+
+def _merge_delete_filter(
+ base_filter: dict[str, Any] | None,
+ constraints: dict[str, Any],
+) -> dict[str, Any]:
+ """Merge user/session constraints into an existing filter."""
+ if not constraints:
+ return base_filter or {}
+ if base_filter is None:
+ return {"and": [constraints.copy()]}
+
+ if not base_filter:
+ return {"and": [constraints.copy()]}
+
+ if "and" in base_filter:
+ and_conditions = base_filter.get("and")
+ if not isinstance(and_conditions, list):
+ raise ValueError("Invalid filter format: 'and' must be a list")
+ return {"and": [*and_conditions, constraints.copy()]}
+
+ if "or" in base_filter:
+ or_conditions = base_filter.get("or")
+ if not isinstance(or_conditions, list):
+ raise ValueError("Invalid filter format: 'or' must be a list")
+
+ merged_or_conditions: list[dict[str, Any]] = []
+ for condition in or_conditions:
+ if not isinstance(condition, dict):
+ raise ValueError("Invalid filter format: each 'or' condition must be a dict")
+ merged_condition = condition.copy()
+ for key, value in constraints.items():
+ if key in merged_condition and merged_condition[key] != value:
+ raise ValueError(
+ f"Conflicting filter condition for '{key}'. "
+ "Please merge it manually into request.filter."
+ )
+ merged_condition[key] = value
+ merged_or_conditions.append(merged_condition)
+
+ return {"or": merged_or_conditions}
+
+ # For plain dict filters, keep strict AND semantics explicitly.
+ return {"and": [base_filter.copy(), constraints.copy()]}
+
+
def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube):
"""
Handler for deleting memories.
Now unified to delete from text_mem only (includes preferences).
"""
logger.info(
- "[Delete memory request] writable_cube_ids: %s, memory_ids: %s, auto_cleanup_working: %s",
+        "[Delete memory request] writable_cube_ids: %s, memory_ids: %s, file_ids: %s, "
+        "has_filter: %s, user_id: %s, session_id: %s, auto_cleanup_working: %s",
delete_mem_req.writable_cube_ids,
delete_mem_req.memory_ids,
+ delete_mem_req.file_ids,
+ delete_mem_req.filter is not None,
+ delete_mem_req.user_id,
+ delete_mem_req.session_id,
getattr(delete_mem_req, "auto_cleanup_working", False),
)
- # Validate that only one of memory_ids, file_ids, or filter is provided
+ quick_constraints = _build_quick_delete_constraints(delete_mem_req)
+ has_non_empty_filter = bool(delete_mem_req.filter)
+ has_filter_mode = has_non_empty_filter or bool(quick_constraints)
+
+ # Reject empty filter dict when no quick constraints are provided.
+ if delete_mem_req.filter is not None and not has_non_empty_filter and not quick_constraints:
+ return DeleteMemoryResponse(
+ message="filter cannot be empty. Provide a non-empty filter or user_id/session_id.",
+ data={"status": "failure"},
+ )
+
+ # Validate that only one mode is provided: memory_ids, file_ids, or filter-mode.
provided_params = [
delete_mem_req.memory_ids is not None,
delete_mem_req.file_ids is not None,
- delete_mem_req.filter is not None,
+ has_filter_mode,
]
if sum(provided_params) != 1:
return DeleteMemoryResponse(
- message="Exactly one of memory_ids, file_ids, or filter must be provided",
+ message=(
+ "Exactly one delete mode must be provided: "
+ "memory_ids, file_ids, or filter/user_id/session_id."
+ ),
data={"status": "failure"},
)
@@ -370,8 +443,14 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
naive_mem_cube.text_mem.delete_by_filter(
writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids
)
- elif delete_mem_req.filter is not None:
- naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter)
+ elif has_filter_mode:
+ merged_filter = _merge_delete_filter(delete_mem_req.filter, quick_constraints)
+ naive_mem_cube.text_mem.delete_by_filter(
+ writable_cube_ids=delete_mem_req.writable_cube_ids,
+ filter=merged_filter,
+ )
+ if naive_mem_cube.pref_mem is not None:
+ naive_mem_cube.pref_mem.delete_by_filter(filter=merged_filter)
# After main deletion, optionally clean up related WorkingMemory nodes.
if working_ids_to_delete:
diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py
deleted file mode 100644
index ec5cccae1..000000000
--- a/src/memos/api/product_api.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import logging
-
-from fastapi import FastAPI
-
-from memos.api.exceptions import APIExceptionHandler
-from memos.api.middleware.request_context import RequestContextMiddleware
-from memos.api.routers.product_router import router as product_router
-
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger(__name__)
-
-app = FastAPI(
- title="MemOS Product REST APIs",
- description="A REST API for managing multiple users with MemOS Product.",
- version="1.0.1",
-)
-
-app.add_middleware(RequestContextMiddleware, source="product_api")
-# Include routers
-app.include_router(product_router)
-
-# Exception handlers
-app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler)
-app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler)
-
-
-if __name__ == "__main__":
- import argparse
-
- import uvicorn
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--port", type=int, default=8001)
- parser.add_argument("--workers", type=int, default=1)
- args = parser.parse_args()
- uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers)
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 59f7f74c8..78dcfc797 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -854,10 +854,22 @@ class GetMemoryDashboardRequest(GetMemoryRequest):
class DeleteMemoryRequest(BaseRequest):
"""Request model for deleting memories."""
- writable_cube_ids: list[str] = Field(None, description="Writable cube IDs")
+ writable_cube_ids: list[str] | None = Field(None, description="Writable cube IDs")
memory_ids: list[str] | None = Field(None, description="Memory IDs")
file_ids: list[str] | None = Field(None, description="File IDs")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
+ user_id: str | None = Field(
+ None,
+ description="Quick delete condition: remove memories for this user_id.",
+ )
+ session_id: str | None = Field(
+ None,
+ description="Quick delete condition: remove memories for this session_id.",
+ )
+ conversation_id: str | None = Field(
+ None,
+ description="Alias of session_id for backward compatibility.",
+ )
auto_cleanup_working: bool | None = Field(
False,
description=(
@@ -866,6 +878,15 @@ class DeleteMemoryRequest(BaseRequest):
),
)
+ @model_validator(mode="after")
+ def normalize_session_alias(self) -> "DeleteMemoryRequest":
+ """Normalize conversation_id to session_id."""
+ if self.conversation_id and self.session_id and self.conversation_id != self.session_id:
+ raise ValueError("conversation_id and session_id must be the same when both are set")
+ if self.session_id is None and self.conversation_id is not None:
+ self.session_id = self.conversation_id
+ return self
+
class SuggestionRequest(BaseRequest):
"""Request model for getting suggestion queries."""
diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py
deleted file mode 100644
index 609d61124..000000000
--- a/src/memos/api/routers/product_router.py
+++ /dev/null
@@ -1,477 +0,0 @@
-import json
-import time
-import traceback
-
-from fastapi import APIRouter, HTTPException
-from fastapi.responses import StreamingResponse
-
-from memos.api.config import APIConfig
-from memos.api.product_models import (
- BaseResponse,
- ChatCompleteRequest,
- ChatRequest,
- GetMemoryPlaygroundRequest,
- MemoryCreateRequest,
- MemoryResponse,
- SearchRequest,
- SearchResponse,
- SimpleResponse,
- SuggestionRequest,
- SuggestionResponse,
- UserRegisterRequest,
- UserRegisterResponse,
-)
-from memos.configs.mem_os import MOSConfig
-from memos.log import get_logger
-from memos.mem_os.product import MOSProduct
-from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function
-
-
-logger = get_logger(__name__)
-
-router = APIRouter(prefix="/product", tags=["Product API"])
-
-# Initialize MOSProduct instance with lazy initialization
-MOS_PRODUCT_INSTANCE = None
-
-
-def get_mos_product_instance():
- """Get or create MOSProduct instance."""
- global MOS_PRODUCT_INSTANCE
- if MOS_PRODUCT_INSTANCE is None:
- default_config = APIConfig.get_product_default_config()
- logger.info(f"*********init_default_mos_config********* {default_config}")
- from memos.configs.mem_os import MOSConfig
-
- mos_config = MOSConfig(**default_config)
-
- # Get default cube config from APIConfig (may be None if disabled)
- default_cube_config = APIConfig.get_default_cube_config()
- logger.info(f"*********initdefault_cube_config******** {default_cube_config}")
-
- # Get DingDing bot functions
- dingding_enabled = APIConfig.is_dingding_bot_enabled()
- online_bot = get_online_bot_function() if dingding_enabled else None
- error_bot = get_error_bot_function() if dingding_enabled else None
-
- MOS_PRODUCT_INSTANCE = MOSProduct(
- default_config=mos_config,
- default_cube_config=default_cube_config,
- online_bot=online_bot,
- error_bot=error_bot,
- )
- logger.info("MOSProduct instance created successfully with inheritance architecture")
- return MOS_PRODUCT_INSTANCE
-
-
-get_mos_product_instance()
-
-
-@router.post("/configure", summary="Configure MOSProduct", response_model=SimpleResponse)
-def set_config(config):
- """Set MOSProduct configuration."""
- global MOS_PRODUCT_INSTANCE
- MOS_PRODUCT_INSTANCE = MOSProduct(default_config=config)
- return SimpleResponse(message="Configuration set successfully")
-
-
-@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
-def register_user(user_req: UserRegisterRequest):
- """Register a new user with configuration and default cube."""
- try:
- # Get configuration for the user
- time_start_register = time.time()
- user_config, default_mem_cube = APIConfig.create_user_config(
- user_name=user_req.user_id, user_id=user_req.user_id
- )
- logger.info(f"user_config: {user_config.model_dump(mode='json')}")
- logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}")
- logger.info(
- f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
- )
- mos_product = get_mos_product_instance()
-
- # Register user with default config and mem cube
- result = mos_product.user_register(
- user_id=user_req.user_id,
- user_name=user_req.user_name,
- interests=user_req.interests,
- config=user_config,
- default_mem_cube=default_mem_cube,
- mem_cube_id=user_req.mem_cube_id,
- )
- logger.info(
- f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
- )
- if result["status"] == "success":
- return UserRegisterResponse(
- message="User registered successfully",
- data={"user_id": result["user_id"], "mem_cube_id": result["default_cube_id"]},
- )
- else:
- raise HTTPException(status_code=400, detail=result["message"])
-
- except Exception as err:
- logger.error(f"Failed to register user: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get(
- "/suggestions/{user_id}", summary="Get suggestion queries", response_model=SuggestionResponse
-)
-def get_suggestion_queries(user_id: str):
- """Get suggestion queries for a specific user."""
- try:
- mos_product = get_mos_product_instance()
- suggestions = mos_product.get_suggestion_query(user_id)
- return SuggestionResponse(
- message="Suggestions retrieved successfully", data={"query": suggestions}
- )
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get suggestions: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post(
- "/suggestions",
- summary="Get suggestion queries with language",
- response_model=SuggestionResponse,
-)
-def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
- """Get suggestion queries for a specific user with language preference."""
- try:
- mos_product = get_mos_product_instance()
- suggestions = mos_product.get_suggestion_query(
- user_id=suggestion_req.user_id,
- language=suggestion_req.language,
- message=suggestion_req.message,
- )
- return SuggestionResponse(
- message="Suggestions retrieved successfully", data={"query": suggestions}
- )
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get suggestions: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
-def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
- """Get all memories for a specific user."""
- try:
- mos_product = get_mos_product_instance()
- if memory_req.search_query:
- result = mos_product.get_subgraph(
- user_id=memory_req.user_id,
- query=memory_req.search_query,
- mem_cube_ids=memory_req.mem_cube_ids,
- )
- return MemoryResponse(message="Memories retrieved successfully", data=result)
- else:
- result = mos_product.get_all(
- user_id=memory_req.user_id,
- memory_type=memory_req.memory_type,
- mem_cube_ids=memory_req.mem_cube_ids,
- )
- return MemoryResponse(message="Memories retrieved successfully", data=result)
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get memories: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/add", summary="add a new memory", response_model=SimpleResponse)
-def create_memory(memory_req: MemoryCreateRequest):
- """Create a new memory for a specific user."""
- logger.info("DIAGNOSTIC: /product/add endpoint called. This confirms the new code is deployed.")
- # Initialize status_tracker outside try block to avoid NameError in except blocks
- status_tracker = None
-
- try:
- time_start_add = time.time()
- mos_product = get_mos_product_instance()
-
- # Track task if task_id is provided
- item_id: str | None = None
- if (
- memory_req.task_id
- and hasattr(mos_product, "mem_scheduler")
- and mos_product.mem_scheduler
- ):
- from uuid import uuid4
-
- from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
-
- item_id = str(uuid4()) # Generate a unique item_id for this submission
-
- # Get Redis client from scheduler
- if (
- hasattr(mos_product.mem_scheduler, "redis_client")
- and mos_product.mem_scheduler.redis_client
- ):
- status_tracker = TaskStatusTracker(mos_product.mem_scheduler.redis_client)
- # Submit task with "product_add" type
- status_tracker.task_submitted(
- task_id=item_id, # Use generated item_id for internal tracking
- user_id=memory_req.user_id,
- task_type="product_add",
- mem_cube_id=memory_req.mem_cube_id or memory_req.user_id,
- business_task_id=memory_req.task_id, # Use memory_req.task_id as business_task_id
- )
- status_tracker.task_started(item_id, memory_req.user_id) # Use item_id here
-
- # Execute the add operation
- mos_product.add(
- user_id=memory_req.user_id,
- memory_content=memory_req.memory_content,
- messages=memory_req.messages,
- doc_path=memory_req.doc_path,
- mem_cube_id=memory_req.mem_cube_id,
- source=memory_req.source,
- user_profile=memory_req.user_profile,
- session_id=memory_req.session_id,
- task_id=memory_req.task_id,
- )
-
- # Mark task as completed
- if status_tracker and item_id:
- status_tracker.task_completed(item_id, memory_req.user_id)
-
- logger.info(
- f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
- )
- return SimpleResponse(message="Memory created successfully")
-
- except ValueError as err:
- # Mark task as failed if tracking
- if status_tracker and item_id:
- status_tracker.task_failed(item_id, memory_req.user_id, str(err))
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- # Mark task as failed if tracking
- if status_tracker and item_id:
- status_tracker.task_failed(item_id, memory_req.user_id, str(err))
- logger.error(f"Failed to create memory: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/search", summary="Search memories", response_model=SearchResponse)
-def search_memories(search_req: SearchRequest):
- """Search memories for a specific user."""
- try:
- time_start_search = time.time()
- mos_product = get_mos_product_instance()
- result = mos_product.search(
- query=search_req.query,
- user_id=search_req.user_id,
- install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
- top_k=search_req.top_k,
- session_id=search_req.session_id,
- )
- logger.info(
- f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}"
- )
- return SearchResponse(message="Search completed successfully", data=result)
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to search memories: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/chat", summary="Chat with MemOS")
-def chat(chat_req: ChatRequest):
- """Chat with MemOS for a specific user. Returns SSE stream."""
- try:
- mos_product = get_mos_product_instance()
-
- def generate_chat_response():
- """Generate chat response as SSE stream."""
- try:
- # Directly yield from the generator without async wrapper
- yield from mos_product.chat_with_references(
- query=chat_req.query,
- user_id=chat_req.user_id,
- cube_id=chat_req.mem_cube_id,
- history=chat_req.history,
- internet_search=chat_req.internet_search,
- moscube=chat_req.moscube,
- session_id=chat_req.session_id,
- )
-
- except Exception as e:
- logger.error(f"Error in chat stream: {e}")
- error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
- yield error_data
-
- return StreamingResponse(
- generate_chat_response(),
- media_type="text/event-stream",
- headers={
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "Content-Type": "text/event-stream",
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Headers": "*",
- "Access-Control-Allow-Methods": "*",
- },
- )
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to start chat: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
-def chat_complete(chat_req: ChatCompleteRequest):
- """Chat with MemOS for a specific user. Returns complete response (non-streaming)."""
- try:
- mos_product = get_mos_product_instance()
-
- # Collect all responses from the generator
- content, references = mos_product.chat(
- query=chat_req.query,
- user_id=chat_req.user_id,
- cube_id=chat_req.mem_cube_id,
- history=chat_req.history,
- internet_search=chat_req.internet_search,
- moscube=chat_req.moscube,
- base_prompt=chat_req.base_prompt or chat_req.system_prompt,
- # will deprecate base_prompt in the future
- top_k=chat_req.top_k,
- threshold=chat_req.threshold,
- session_id=chat_req.session_id,
- )
-
- # Return the complete response
- return {
- "message": "Chat completed successfully",
- "data": {"response": content, "references": references},
- }
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to start chat: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get("/users", summary="List all users", response_model=BaseResponse[list])
-def list_users():
- """List all registered users."""
- try:
- mos_product = get_mos_product_instance()
- users = mos_product.list_users()
- return BaseResponse(message="Users retrieved successfully", data=users)
- except Exception as err:
- logger.error(f"Failed to list users: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get("/users/{user_id}", summary="Get user info", response_model=BaseResponse[dict])
-async def get_user_info(user_id: str):
- """Get user information including accessible cubes."""
- try:
- mos_product = get_mos_product_instance()
- user_info = mos_product.get_user_info(user_id)
- return BaseResponse(message="User info retrieved successfully", data=user_info)
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get user info: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get(
- "/configure/{user_id}", summary="Get MOSProduct configuration", response_model=SimpleResponse
-)
-def get_config(user_id: str):
- """Get MOSProduct configuration."""
- global MOS_PRODUCT_INSTANCE
- config = MOS_PRODUCT_INSTANCE.default_config
- return SimpleResponse(message="Configuration retrieved successfully", data=config)
-
-
-@router.get(
- "/users/{user_id}/config", summary="Get user configuration", response_model=BaseResponse[dict]
-)
-def get_user_config(user_id: str):
- """Get user-specific configuration."""
- try:
- mos_product = get_mos_product_instance()
- config = mos_product.get_user_config(user_id)
- if config:
- return BaseResponse(
- message="User configuration retrieved successfully",
- data=config.model_dump(mode="json"),
- )
- else:
- raise HTTPException(
- status_code=404, detail=f"Configuration not found for user {user_id}"
- )
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get user config: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.put(
- "/users/{user_id}/config", summary="Update user configuration", response_model=SimpleResponse
-)
-def update_user_config(user_id: str, config_data: dict):
- """Update user-specific configuration."""
- try:
- mos_product = get_mos_product_instance()
-
- # Create MOSConfig from the provided data
- config = MOSConfig(**config_data)
-
- # Update the configuration
- success = mos_product.update_user_config(user_id, config)
- if success:
- return SimpleResponse(message="User configuration updated successfully")
- else:
- raise HTTPException(status_code=500, detail="Failed to update user configuration")
-
- except ValueError as err:
- raise HTTPException(status_code=400, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to update user config: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get(
- "/instances/status", summary="Get user configuration status", response_model=BaseResponse[dict]
-)
-def get_instance_status():
- """Get information about active user configurations in memory."""
- try:
- mos_product = get_mos_product_instance()
- status_info = mos_product.get_user_instance_info()
- return BaseResponse(
- message="User configuration status retrieved successfully", data=status_info
- )
- except Exception as err:
- logger.error(f"Failed to get user configuration status: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get("/instances/count", summary="Get active user count", response_model=BaseResponse[int])
-def get_active_user_count():
- """Get the number of active user configurations in memory."""
- try:
- mos_product = get_mos_product_instance()
- count = mos_product.get_active_user_count()
- return BaseResponse(message="Active user count retrieved successfully", data=count)
- except Exception as err:
- logger.error(f"Failed to get active user count: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py
index 529a709a4..1f1e6ccde 100644
--- a/src/memos/api/server_api.py
+++ b/src/memos/api/server_api.py
@@ -29,6 +29,17 @@
# Include routers
app.include_router(server_router)
+
+@app.get("/health")
+def health_check():
+ """Container and load balancer health endpoint."""
+ return {
+ "status": "healthy",
+ "service": "memos",
+ "version": app.version,
+ }
+
+
# Request validation failed
app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler)
# Invalid business code parameters
diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py
deleted file mode 100644
index 24a36f017..000000000
--- a/src/memos/api/start_api.py
+++ /dev/null
@@ -1,433 +0,0 @@
-import logging
-import os
-
-from typing import Any, Generic, TypeVar
-
-from dotenv import load_dotenv
-from fastapi import FastAPI
-from fastapi.requests import Request
-from fastapi.responses import JSONResponse, RedirectResponse
-from pydantic import BaseModel, Field
-
-from memos.api.middleware.request_context import RequestContextMiddleware
-from memos.configs.mem_os import MOSConfig
-from memos.mem_os.main import MOS
-from memos.mem_user.user_manager import UserManager, UserRole
-
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger(__name__)
-
-# Load environment variables
-load_dotenv(override=True)
-
-T = TypeVar("T")
-
-# Default configuration
-DEFAULT_CONFIG = {
- "user_id": os.getenv("MOS_USER_ID", "default_user"),
- "session_id": os.getenv("MOS_SESSION_ID", "default_session"),
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "top_k": int(os.getenv("MOS_TOP_K", "5")),
- "chat_model": {
- "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"),
- "config": {
- "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-3.5-turbo"),
- "api_key": os.getenv("OPENAI_API_KEY", "apikey"),
- "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.7")),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
-}
-
-# Initialize MOS instance with lazy initialization
-MOS_INSTANCE = None
-
-
-def get_mos_instance():
- """Get or create MOS instance with default user creation."""
- global MOS_INSTANCE
- if MOS_INSTANCE is None:
- # Create a temporary MOS instance to access user manager
- temp_config = MOSConfig(**DEFAULT_CONFIG)
- temp_mos = MOS.__new__(MOS)
- temp_mos.config = temp_config
- temp_mos.user_id = temp_config.user_id
- temp_mos.session_id = temp_config.session_id
- temp_mos.mem_cubes = {}
- temp_mos.chat_llm = None # Will be initialized later
- temp_mos.user_manager = UserManager()
-
- # Create default user if it doesn't exist
- if not temp_mos.user_manager.validate_user(temp_config.user_id):
- temp_mos.user_manager.create_user(
- user_name=temp_config.user_id, role=UserRole.USER, user_id=temp_config.user_id
- )
- logger.info(f"Created default user: {temp_config.user_id}")
-
- # Now create the actual MOS instance
- MOS_INSTANCE = MOS(config=temp_config)
-
- return MOS_INSTANCE
-
-
-app = FastAPI(
- title="MemOS REST APIs",
- description="A REST API for managing and searching memories using MemOS.",
- version="1.0.0",
-)
-
-app.add_middleware(RequestContextMiddleware)
-
-
-class BaseRequest(BaseModel):
- """Base model for all requests."""
-
- user_id: str | None = Field(
- None, description="User ID for the request", json_schema_extra={"example": "user123"}
- )
-
-
-class BaseResponse(BaseModel, Generic[T]):
- """Base model for all responses."""
-
- code: int = Field(200, description="Response status code", json_schema_extra={"example": 200})
- message: str = Field(
- ..., description="Response message", json_schema_extra={"example": "Operation successful"}
- )
- data: T | None = Field(None, description="Response data")
-
-
-class Message(BaseModel):
- role: str = Field(
- ...,
- description="Role of the message (user or assistant).",
- json_schema_extra={"example": "user"},
- )
- content: str = Field(
- ...,
- description="Message content.",
- json_schema_extra={"example": "Hello, how can I help you?"},
- )
-
-
-class MemoryCreate(BaseRequest):
- messages: list[Message] | None = Field(
- None,
- description="List of messages to store.",
- json_schema_extra={"example": [{"role": "user", "content": "Hello"}]},
- )
- mem_cube_id: str | None = Field(
- None, description="ID of the memory cube", json_schema_extra={"example": "cube123"}
- )
- memory_content: str | None = Field(
- None,
- description="Content to store as memory",
- json_schema_extra={"example": "This is a memory content"},
- )
- doc_path: str | None = Field(
- None,
- description="Path to document to store",
- json_schema_extra={"example": "/path/to/document.txt"},
- )
-
-
-class SearchRequest(BaseRequest):
- query: str = Field(
- ...,
- description="Search query.",
- json_schema_extra={"example": "How to implement a feature?"},
- )
- install_cube_ids: list[str] | None = Field(
- None,
- description="List of cube IDs to search in",
- json_schema_extra={"example": ["cube123", "cube456"]},
- )
-
-
-class MemCubeRegister(BaseRequest):
- mem_cube_name_or_path: str = Field(
- ...,
- description="Name or path of the MemCube to register.",
- json_schema_extra={"example": "/path/to/cube"},
- )
- mem_cube_id: str | None = Field(
- None, description="ID for the MemCube", json_schema_extra={"example": "cube123"}
- )
-
-
-class ChatRequest(BaseRequest):
- query: str = Field(
- ...,
- description="Chat query message.",
- json_schema_extra={"example": "What is the latest update?"},
- )
-
-
-class UserCreate(BaseRequest):
- user_name: str | None = Field(
- None, description="Name of the user", json_schema_extra={"example": "john_doe"}
- )
- role: str = Field("user", description="Role of the user", json_schema_extra={"example": "user"})
- user_id: str = Field(..., description="User ID", json_schema_extra={"example": "user123"})
-
-
-class CubeShare(BaseRequest):
- target_user_id: str = Field(
- ..., description="Target user ID to share with", json_schema_extra={"example": "user456"}
- )
-
-
-class SimpleResponse(BaseResponse[None]):
- """Simple response model for operations without data return."""
-
-
-class ConfigResponse(BaseResponse[None]):
- """Response model for configuration endpoint."""
-
-
-class MemoryResponse(BaseResponse[dict]):
- """Response model for memory operations."""
-
-
-class SearchResponse(BaseResponse[dict]):
- """Response model for search operations."""
-
-
-class ChatResponse(BaseResponse[str]):
- """Response model for chat operations."""
-
-
-class UserResponse(BaseResponse[dict]):
- """Response model for user operations."""
-
-
-class UserListResponse(BaseResponse[list]):
- """Response model for user list operations."""
-
-
-@app.post("/configure", summary="Configure MemOS", response_model=ConfigResponse)
-async def set_config(config: MOSConfig):
- """Set MemOS configuration."""
- global MOS_INSTANCE
-
- # Create a temporary user manager to check/create default user
- temp_user_manager = UserManager()
-
- # Create default user if it doesn't exist
- if not temp_user_manager.validate_user(config.user_id):
- temp_user_manager.create_user(
- user_name=config.user_id, role=UserRole.USER, user_id=config.user_id
- )
- logger.info(f"Created default user: {config.user_id}")
-
- # Now create the MOS instance
- MOS_INSTANCE = MOS(config=config)
- return ConfigResponse(message="Configuration set successfully")
-
-
-@app.post("/users", summary="Create a new user", response_model=UserResponse)
-async def create_user(user_create: UserCreate):
- """Create a new user."""
- mos_instance = get_mos_instance()
- role = UserRole(user_create.role)
- user_id = mos_instance.create_user(
- user_id=user_create.user_id, role=role, user_name=user_create.user_name
- )
- return UserResponse(message="User created successfully", data={"user_id": user_id})
-
-
-@app.get("/users", summary="List all users", response_model=UserListResponse)
-async def list_users():
- """List all active users."""
- mos_instance = get_mos_instance()
- users = mos_instance.list_users()
- return UserListResponse(message="Users retrieved successfully", data=users)
-
-
-@app.get("/users/me", summary="Get current user info", response_model=UserResponse)
-async def get_user_info():
- """Get current user information including accessible cubes."""
- mos_instance = get_mos_instance()
- user_info = mos_instance.get_user_info()
- return UserResponse(message="User info retrieved successfully", data=user_info)
-
-
-@app.post("/mem_cubes", summary="Register a MemCube", response_model=SimpleResponse)
-async def register_mem_cube(mem_cube: MemCubeRegister):
- """Register a new MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.register_mem_cube(
- mem_cube_name_or_path=mem_cube.mem_cube_name_or_path,
- mem_cube_id=mem_cube.mem_cube_id,
- user_id=mem_cube.user_id,
- )
- return SimpleResponse(message="MemCube registered successfully")
-
-
-@app.delete(
- "/mem_cubes/{mem_cube_id}", summary="Unregister a MemCube", response_model=SimpleResponse
-)
-async def unregister_mem_cube(mem_cube_id: str, user_id: str | None = None):
- """Unregister a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.unregister_mem_cube(mem_cube_id=mem_cube_id, user_id=user_id)
- return SimpleResponse(message="MemCube unregistered successfully")
-
-
-@app.post(
- "/mem_cubes/{cube_id}/share",
- summary="Share a cube with another user",
- response_model=SimpleResponse,
-)
-async def share_cube(cube_id: str, share_request: CubeShare):
- """Share a cube with another user."""
- mos_instance = get_mos_instance()
- success = mos_instance.share_cube_with_user(cube_id, share_request.target_user_id)
- if success:
- return SimpleResponse(message="Cube shared successfully")
- else:
- raise ValueError("Failed to share cube")
-
-
-@app.post("/memories", summary="Create memories", response_model=SimpleResponse)
-async def add_memory(memory_create: MemoryCreate):
- """Store new memories in a MemCube."""
- if not any([memory_create.messages, memory_create.memory_content, memory_create.doc_path]):
- raise ValueError("Either messages, memory_content, or doc_path must be provided")
- mos_instance = get_mos_instance()
- if memory_create.messages:
- messages = [m.model_dump() for m in memory_create.messages]
- mos_instance.add(
- messages=messages,
- mem_cube_id=memory_create.mem_cube_id,
- user_id=memory_create.user_id,
- )
- elif memory_create.memory_content:
- mos_instance.add(
- memory_content=memory_create.memory_content,
- mem_cube_id=memory_create.mem_cube_id,
- user_id=memory_create.user_id,
- )
- elif memory_create.doc_path:
- mos_instance.add(
- doc_path=memory_create.doc_path,
- mem_cube_id=memory_create.mem_cube_id,
- user_id=memory_create.user_id,
- )
- return SimpleResponse(message="Memories added successfully")
-
-
-@app.get("/memories", summary="Get all memories", response_model=MemoryResponse)
-async def get_all_memories(
- mem_cube_id: str | None = None,
- user_id: str | None = None,
-):
- """Retrieve all memories from a MemCube."""
- mos_instance = get_mos_instance()
- result = mos_instance.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
- return MemoryResponse(message="Memories retrieved successfully", data=result)
-
-
-@app.get(
- "/memories/{mem_cube_id}/{memory_id}", summary="Get a memory", response_model=MemoryResponse
-)
-async def get_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None):
- """Retrieve a specific memory by ID from a MemCube."""
- mos_instance = get_mos_instance()
- result = mos_instance.get(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id)
- return MemoryResponse(message="Memory retrieved successfully", data=result)
-
-
-@app.post("/search", summary="Search memories", response_model=SearchResponse)
-async def search_memories(search_req: SearchRequest):
- """Search for memories across MemCubes."""
- mos_instance = get_mos_instance()
- result = mos_instance.search(
- query=search_req.query,
- user_id=search_req.user_id,
- install_cube_ids=search_req.install_cube_ids,
- )
- return SearchResponse(message="Search completed successfully", data=result)
-
-
-@app.put(
- "/memories/{mem_cube_id}/{memory_id}", summary="Update a memory", response_model=SimpleResponse
-)
-async def update_memory(
- mem_cube_id: str, memory_id: str, updated_memory: dict[str, Any], user_id: str | None = None
-):
- """Update an existing memory in a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.update(
- mem_cube_id=mem_cube_id,
- memory_id=memory_id,
- text_memory_item=updated_memory,
- user_id=user_id,
- )
- return SimpleResponse(message="Memory updated successfully")
-
-
-@app.delete(
- "/memories/{mem_cube_id}/{memory_id}", summary="Delete a memory", response_model=SimpleResponse
-)
-async def delete_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None):
- """Delete a specific memory from a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.delete(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id)
- return SimpleResponse(message="Memory deleted successfully")
-
-
-@app.delete("/memories/{mem_cube_id}", summary="Delete all memories", response_model=SimpleResponse)
-async def delete_all_memories(mem_cube_id: str, user_id: str | None = None):
- """Delete all memories from a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.delete_all(mem_cube_id=mem_cube_id, user_id=user_id)
- return SimpleResponse(message="All memories deleted successfully")
-
-
-@app.post("/chat", summary="Chat with MemOS", response_model=ChatResponse)
-async def chat(chat_req: ChatRequest):
- """Chat with the MemOS system."""
- mos_instance = get_mos_instance()
- response = mos_instance.chat(query=chat_req.query, user_id=chat_req.user_id)
- if response is None:
- raise ValueError("No response generated")
- return ChatResponse(message="Chat response generated", data=response)
-
-
-@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False)
-async def home():
- """Redirect to the OpenAPI documentation."""
- return RedirectResponse(url="/docs", status_code=307)
-
-
-@app.exception_handler(ValueError)
-async def value_error_handler(request: Request, exc: ValueError):
- """Handle ValueError exceptions globally."""
- return JSONResponse(
- status_code=400,
- content={"code": 400, "message": str(exc), "data": None},
- )
-
-
-@app.exception_handler(Exception)
-async def global_exception_handler(request: Request, exc: Exception):
- """Handle all unhandled exceptions globally."""
- logger.exception("Unhandled error:")
- return JSONResponse(
- status_code=500,
- content={"code": 500, "message": str(exc), "data": None},
- )
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
- parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
- parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
- args = parser.parse_args()
diff --git a/src/memos/cli.py b/src/memos/cli.py
index 092f2d276..2ead5ab29 100644
--- a/src/memos/cli.py
+++ b/src/memos/cli.py
@@ -11,9 +11,16 @@
from io import BytesIO
+def get_openapi_app():
+ """Return the FastAPI app used for OpenAPI export."""
+ from memos.api.server_api import app
+
+ return app
+
+
def export_openapi(output: str) -> bool:
"""Export OpenAPI schema to JSON file."""
- from memos.api.server_api import app
+ app = get_openapi_app()
# Create directory if it doesn't exist
if os.path.dirname(output):
diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py
index 5900d2357..98de09812 100644
--- a/src/memos/configs/graph_db.py
+++ b/src/memos/configs/graph_db.py
@@ -103,57 +103,6 @@ def validate_community(self):
return self
-class NebulaGraphDBConfig(BaseGraphDBConfig):
- """
- NebulaGraph-specific configuration.
-
- Key concepts:
- - `space`: Equivalent to a database or namespace. All tag/edge/schema live within a space.
- - `user_name`: Used for logical tenant isolation if needed.
- - `auto_create`: Whether to automatically create the target space if it does not exist.
-
- Example:
- ---
- hosts = ["127.0.0.1:9669"]
- user = "root"
- password = "nebula"
- space = "shared_graph"
- user_name = "alice"
- """
-
- space: str = Field(
- ..., description="The name of the target NebulaGraph space (like a database)"
- )
- user_name: str | None = Field(
- default=None,
- description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)",
- )
- auto_create: bool = Field(
- default=False,
- description="Whether to auto-create the space if it does not exist",
- )
- use_multi_db: bool = Field(
- default=True,
- description=(
- "If True: use Neo4j's multi-database feature for physical isolation; "
- "each user typically gets a separate database. "
- "If False: use a single shared database with logical isolation by user_name."
- ),
- )
- max_client: int = Field(
- default=1000,
- description=("max_client"),
- )
- embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")
-
- @model_validator(mode="after")
- def validate_config(self):
- """Validate config."""
- if not self.space:
- raise ValueError("`space` must be provided")
- return self
-
-
class PolarDBGraphDBConfig(BaseConfig):
"""
PolarDB-specific configuration.
@@ -299,7 +248,6 @@ class GraphDBConfigFactory(BaseModel):
backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDBConfig,
"neo4j-community": Neo4jCommunityGraphDBConfig,
- "nebular": NebulaGraphDBConfig,
"polardb": PolarDBGraphDBConfig,
"postgres": PostgresGraphDBConfig,
}
diff --git a/src/memos/configs/internet_retriever.py b/src/memos/configs/internet_retriever.py
index 1c5e2b8ad..562cfdd1f 100644
--- a/src/memos/configs/internet_retriever.py
+++ b/src/memos/configs/internet_retriever.py
@@ -67,6 +67,19 @@ class BochaSearchConfig(BaseInternetRetrieverConfig):
)
+class TavilySearchConfig(BaseInternetRetrieverConfig):
+ """Configuration class for Tavily Search API."""
+
+ search_engine_id: str | None = Field(
+ None, description="Not used for Tavily Search (kept for compatibility)"
+ )
+ max_results: int = Field(default=10, description="Maximum number of results to retrieve")
+ search_depth: str = Field(default="basic", description="Search depth: 'basic' or 'advanced'")
+ include_answer: bool = Field(
+ default=False, description="Whether to include an AI-generated answer"
+ )
+
+
class InternetRetrieverConfigFactory(BaseConfig):
"""Factory class for creating internet retriever configurations."""
@@ -82,6 +95,7 @@ class InternetRetrieverConfigFactory(BaseConfig):
"bing": BingSearchConfig,
"xinyu": XinyuSearchConfig,
"bocha": BochaSearchConfig,
+ "tavily": TavilySearchConfig,
}
@field_validator("backend")
diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py
index 11c39b33c..81f7038fa 100644
--- a/src/memos/configs/llm.py
+++ b/src/memos/configs/llm.py
@@ -72,6 +72,15 @@ class DeepSeekLLMConfig(OpenAILLMConfig):
)
+class MinimaxLLMConfig(OpenAILLMConfig):
+ api_key: str = Field(..., description="API key for MiniMax")
+ api_base: str = Field(
+ default="https://api.minimax.io/v1",
+ description="Base URL for MiniMax OpenAI-compatible API",
+ )
+ extra_body: Any = Field(default=None, description="Extra options for API")
+
+
class AzureLLMConfig(BaseLLMConfig):
base_url: str = Field(
default="https://api.openai.azure.com/",
@@ -146,6 +155,7 @@ class LLMConfigFactory(BaseConfig):
"huggingface_singleton": HFLLMConfig, # Add singleton support
"qwen": QwenLLMConfig,
"deepseek": DeepSeekLLMConfig,
+ "minimax": MinimaxLLMConfig,
"openai_new": OpenAIResponsesLLMConfig,
}
diff --git a/src/memos/context/context.py b/src/memos/context/context.py
index 5c8401732..5347de880 100644
--- a/src/memos/context/context.py
+++ b/src/memos/context/context.py
@@ -155,7 +155,7 @@ def get_current_user_name() -> str | None:
def get_current_source() -> str | None:
"""
- Get the current request's source (e.g., 'product_api' or 'server_api').
+ Get the current request's source (for example, 'server_api').
"""
context = _request_context.get()
if context:
diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py
index c207e3190..93b5971ec 100644
--- a/src/memos/graph_dbs/factory.py
+++ b/src/memos/graph_dbs/factory.py
@@ -2,7 +2,6 @@
from memos.configs.graph_db import GraphDBConfigFactory
from memos.graph_dbs.base import BaseGraphDB
-from memos.graph_dbs.nebular import NebulaGraphDB
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
from memos.graph_dbs.polardb import PolarDBGraphDB
@@ -15,7 +14,6 @@ class GraphStoreFactory(BaseGraphDB):
backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDB,
"neo4j-community": Neo4jCommunityGraphDB,
- "nebular": NebulaGraphDB,
"polardb": PolarDBGraphDB,
"postgres": PostgresGraphDB,
}
diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
deleted file mode 100644
index 428d6d09e..000000000
--- a/src/memos/graph_dbs/nebular.py
+++ /dev/null
@@ -1,1794 +0,0 @@
-import json
-import traceback
-
-from contextlib import suppress
-from datetime import datetime
-from threading import Lock
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
-
-import numpy as np
-
-from memos.configs.graph_db import NebulaGraphDBConfig
-from memos.dependency import require_python_package
-from memos.graph_dbs.base import BaseGraphDB
-from memos.log import get_logger
-from memos.utils import timed
-
-
-if TYPE_CHECKING:
- from nebulagraph_python import (
- NebulaClient,
- )
-
-
-logger = get_logger(__name__)
-
-
-_TRANSIENT_ERR_KEYS = (
- "Session not found",
- "Connection not established",
- "timeout",
- "deadline exceeded",
- "Broken pipe",
- "EOFError",
- "socket closed",
- "connection reset",
- "connection refused",
-)
-
-
-@timed
-def _normalize(vec: list[float]) -> list[float]:
- v = np.asarray(vec, dtype=np.float32)
- norm = np.linalg.norm(v)
- return (v / (norm if norm else 1.0)).tolist()
-
-
-@timed
-def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
- node_id = item["id"]
- memory = item["memory"]
- metadata = item.get("metadata", {})
- return node_id, memory, metadata
-
-
-@timed
-def _escape_str(value: str) -> str:
- out = []
- for ch in value:
- code = ord(ch)
- if ch == "\\":
- out.append("\\\\")
- elif ch == '"':
- out.append('\\"')
- elif ch == "\n":
- out.append("\\n")
- elif ch == "\r":
- out.append("\\r")
- elif ch == "\t":
- out.append("\\t")
- elif ch == "\b":
- out.append("\\b")
- elif ch == "\f":
- out.append("\\f")
- elif code < 0x20 or code in (0x2028, 0x2029):
- out.append(f"\\u{code:04x}")
- else:
- out.append(ch)
- return "".join(out)
-
-
-@timed
-def _format_datetime(value: str | datetime) -> str:
- """Ensure datetime is in ISO 8601 format string."""
- if isinstance(value, datetime):
- return value.isoformat()
- return str(value)
-
-
-@timed
-def _normalize_datetime(val):
- """
- Normalize datetime to ISO 8601 UTC string with +00:00.
- - If val is datetime object -> keep isoformat() (Neo4j)
- - If val is string without timezone -> append +00:00 (Nebula)
- - Otherwise just str()
- """
- if hasattr(val, "isoformat"):
- return val.isoformat()
- if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
- return val + "+08:00"
- return str(val)
-
-
-class NebulaGraphDB(BaseGraphDB):
- """
- NebulaGraph-based implementation of a graph memory store.
- """
-
- # ====== shared pool cache & refcount ======
- # These are process-local; in a multi-process model each process will
- # have its own cache.
- _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
- _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
- _CLIENT_LOCK: ClassVar[Lock] = Lock()
- _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
-
- @staticmethod
- def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
- hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
- if isinstance(hosts, str):
- return [hosts]
- return list(hosts or [])
-
- @staticmethod
- def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
- hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
- return "|".join(
- [
- "nebula-sync",
- ",".join(hosts),
- str(getattr(cfg, "user", "")),
- str(getattr(cfg, "space", "")),
- ]
- )
-
- @classmethod
- def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
- tmp = object.__new__(NebulaGraphDB)
- tmp.config = cfg
- tmp.db_name = cfg.space
- tmp.user_name = None
- tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
- tmp.default_memory_dimension = 3072
- tmp.common_fields = {
- "id",
- "memory",
- "user_name",
- "user_id",
- "session_id",
- "status",
- "key",
- "confidence",
- "tags",
- "created_at",
- "updated_at",
- "memory_type",
- "sources",
- "source",
- "node_type",
- "visibility",
- "usage",
- "background",
- }
- tmp.base_fields = set(tmp.common_fields) - {"usage"}
- tmp.heavy_fields = {"usage"}
- tmp.dim_field = (
- f"embedding_{tmp.embedding_dimension}"
- if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
- else "embedding"
- )
- tmp.system_db_name = cfg.space
- tmp._client = client
- tmp._owns_client = False
- return tmp
-
- @classmethod
- def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
- from nebulagraph_python import (
- ConnectionConfig,
- NebulaClient,
- SessionConfig,
- SessionPoolConfig,
- )
-
- key = cls._make_client_key(cfg)
- with cls._CLIENT_LOCK:
- client = cls._CLIENT_CACHE.get(key)
- if client is None:
- # Connection setting
-
- tmp_client = NebulaClient(
- hosts=cfg.uri,
- username=cfg.user,
- password=cfg.password,
- session_config=SessionConfig(graph=None),
- session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
- )
- try:
- cls._ensure_space_exists(tmp_client, cfg)
- finally:
- tmp_client.close()
-
- conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
- if conn_conf is None:
- conn_conf = ConnectionConfig.from_defults(
- cls._get_hosts_from_cfg(cfg),
- getattr(cfg, "ssl_param", None),
- )
-
- sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
- pool_conf = SessionPoolConfig(
- size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
- )
-
- client = NebulaClient(
- hosts=conn_conf.hosts,
- username=cfg.user,
- password=cfg.password,
- conn_config=conn_conf,
- session_config=sess_conf,
- session_pool_config=pool_conf,
- )
- cls._CLIENT_CACHE[key] = client
- cls._CLIENT_REFCOUNT[key] = 0
- logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
-
- cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
-
- if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
- try:
- pass
- finally:
- pass
-
- if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
- with cls._CLIENT_LOCK:
- if key not in cls._CLIENT_INIT_DONE:
- admin = cls._bootstrap_admin(cfg, client)
- try:
- admin._ensure_database_exists()
- admin._create_basic_property_indexes()
- admin._create_vector_index(
- dimensions=int(
- admin.embedding_dimension or admin.default_memory_dimension
- ),
- )
- cls._CLIENT_INIT_DONE.add(key)
- logger.info("[NebulaGraphDBSync] One-time init done")
- except Exception:
- logger.exception("[NebulaGraphDBSync] One-time init failed")
-
- return key, client
-
- def _refresh_client(self):
- """
- refresh NebulaClient:
- """
- old_key = getattr(self, "_client_key", None)
- if not old_key:
- return
-
- cls = self.__class__
- with cls._CLIENT_LOCK:
- try:
- if old_key in cls._CLIENT_CACHE:
- try:
- cls._CLIENT_CACHE[old_key].close()
- except Exception as e:
- logger.warning(f"[refresh_client] close old client error: {e}")
- finally:
- cls._CLIENT_CACHE.pop(old_key, None)
- finally:
- cls._CLIENT_REFCOUNT[old_key] = 0
-
- new_key, new_client = cls._get_or_create_shared_client(self.config)
- self._client_key = new_key
- self._client = new_client
- logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
-
- @classmethod
- def _release_shared_client(cls, key: str):
- with cls._CLIENT_LOCK:
- if key not in cls._CLIENT_CACHE:
- return
- cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
- if cls._CLIENT_REFCOUNT[key] == 0:
- try:
- cls._CLIENT_CACHE[key].close()
- except Exception as e:
- logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
- finally:
- cls._CLIENT_CACHE.pop(key, None)
- cls._CLIENT_REFCOUNT.pop(key, None)
- logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
-
- @classmethod
- def close_all_shared_clients(cls):
- with cls._CLIENT_LOCK:
- for key, client in list(cls._CLIENT_CACHE.items()):
- try:
- client.close()
- except Exception as e:
- logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
- finally:
- logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
- cls._CLIENT_CACHE.clear()
- cls._CLIENT_REFCOUNT.clear()
-
- @require_python_package(
- import_name="nebulagraph_python",
- install_command="pip install nebulagraph-python>=5.1.1",
- install_link=".....",
- )
- def __init__(self, config: NebulaGraphDBConfig):
- """
- NebulaGraph DB client initialization.
-
- Required config attributes:
- - hosts: list[str] like ["host1:port", "host2:port"]
- - user: str
- - password: str
- - db_name: str (optional for basic commands)
-
- Example config:
- {
- "hosts": ["xxx.xx.xx.xxx:xxxx"],
- "user": "root",
- "password": "nebula",
- "space": "test"
- }
- """
-
- assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
- self.config = config
- self.db_name = config.space
- self.user_name = config.user_name
- self.embedding_dimension = config.embedding_dimension
- self.default_memory_dimension = 3072
- self.common_fields = {
- "id",
- "memory",
- "user_name",
- "user_id",
- "session_id",
- "status",
- "key",
- "confidence",
- "tags",
- "created_at",
- "updated_at",
- "memory_type",
- "sources",
- "source",
- "node_type",
- "visibility",
- "usage",
- "background",
- }
- self.base_fields = set(self.common_fields) - {"usage"}
- self.heavy_fields = {"usage"}
- self.dim_field = (
- f"embedding_{self.embedding_dimension}"
- if (str(self.embedding_dimension) != str(self.default_memory_dimension))
- else "embedding"
- )
- self.system_db_name = config.space
-
- # ---- NEW: pool acquisition strategy
- # Get or create a shared pool from the class-level cache
- self._client_key, self._client = self._get_or_create_shared_client(config)
- self._owns_client = True
-
- logger.info("Connected to NebulaGraph successfully.")
-
- @timed
- def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
- def _wrap_use_db(q: str) -> str:
- if auto_set_db and self.db_name:
- return f"USE `{self.db_name}`\n{q}"
- return q
-
- try:
- return self._client.execute(_wrap_use_db(gql), timeout=timeout)
-
- except Exception as e:
- emsg = str(e)
- if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
- logger.warning(f"[execute_query] {e!s} โ refreshing session pool and retry once...")
- try:
- self._refresh_client()
- return self._client.execute(_wrap_use_db(gql), timeout=timeout)
- except Exception:
- logger.exception("[execute_query] retry after refresh failed")
- raise
- raise
-
- @timed
- def close(self):
- """
- Close the connection resource if this instance owns it.
-
- - If pool was injected (`shared_pool`), do nothing.
- - If pool was acquired via shared cache, decrement refcount and close
- when the last owner releases it.
- """
- if not self._owns_client:
- logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
- return
- if self._client_key:
- self._release_shared_client(self._client_key)
- self._client_key = None
- self._client = None
-
- # NOTE: __del__ is best-effort; do not rely on GC order.
- def __del__(self):
- with suppress(Exception):
- self.close()
-
- @timed
- def create_index(
- self,
- label: str = "Memory",
- vector_property: str = "embedding",
- dimensions: int = 3072,
- index_name: str = "memory_vector_index",
- ) -> None:
- # Create vector index
- self._create_vector_index(label, vector_property, dimensions, index_name)
- # Create indexes
- self._create_basic_property_indexes()
-
- @timed
- def remove_oldest_memory(
- self, memory_type: str, keep_latest: int, user_name: str | None = None
- ) -> None:
- """
- Remove all WorkingMemory nodes except the latest `keep_latest` entries.
-
- Args:
- memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
- keep_latest (int): Number of latest WorkingMemory entries to keep.
- user_name(str): optional user_name.
- """
- try:
- user_name = user_name if user_name else self.config.user_name
- optional_condition = f"AND n.user_name = '{user_name}'"
- count = self.count_nodes(memory_type, user_name)
- if count > keep_latest:
- delete_query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE n.memory_type = '{memory_type}'
- {optional_condition}
- ORDER BY n.updated_at DESC
- OFFSET {int(keep_latest)}
- DETACH DELETE n
- """
- self.execute_query(delete_query)
- except Exception as e:
- logger.warning(f"Delete old mem error: {e}")
-
- @timed
- def add_node(
- self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
- ) -> None:
- """
- Insert or update a Memory node in NebulaGraph.
- """
- metadata["user_name"] = user_name if user_name else self.config.user_name
- now = datetime.utcnow()
- metadata = metadata.copy()
- metadata.setdefault("created_at", now)
- metadata.setdefault("updated_at", now)
- metadata["node_type"] = metadata.pop("type")
- metadata["id"] = id
- metadata["memory"] = memory
-
- if "embedding" in metadata and isinstance(metadata["embedding"], list):
- assert len(metadata["embedding"]) == self.embedding_dimension, (
- f"input embedding dimension must equal to {self.embedding_dimension}"
- )
- embedding = metadata.pop("embedding")
- metadata[self.dim_field] = _normalize(embedding)
-
- metadata = self._metadata_filter(metadata)
- properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
- gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
-
- try:
- self.execute_query(gql)
- logger.info("insert success")
- except Exception as e:
- logger.error(
- f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
- )
-
- @timed
- def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"'
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {filter_clause}
- RETURN n.id AS id
- LIMIT 1
- """
-
- try:
- result = self.execute_query(query)
- return result.size == 0
- except Exception as e:
- logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
- raise
-
- @timed
- def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
- """
- Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
- """
- user_name = user_name if user_name else self.config.user_name
- fields = fields.copy()
- set_clauses = []
- for k, v in fields.items():
- set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
-
- set_clause_str = ",\n ".join(set_clauses)
-
- query = f"""
- MATCH (n@Memory {{id: "{id}"}})
- """
- query += f'WHERE n.user_name = "{user_name}"'
-
- query += f"\nSET {set_clause_str}"
- self.execute_query(query)
-
- @timed
- def delete_node(self, id: str, user_name: str | None = None) -> None:
- """
- Delete a node from the graph.
- Args:
- id: Node identifier to delete.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)}
- DETACH DELETE n
- """
- self.execute_query(query)
-
- @timed
- def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None):
- """
- Create an edge from source node to target node.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- if not source_id or not target_id:
- raise ValueError("[add_edge] source_id and target_id must be provided")
- user_name = user_name if user_name else self.config.user_name
- props = ""
- props = f'{{user_name: "{user_name}"}}'
- insert_stmt = f'''
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
- INSERT (a) -[e@{type} {props}]-> (b)
- '''
- try:
- self.execute_query(insert_stmt)
- except Exception as e:
- logger.error(f"Failed to insert edge: {e}", exc_info=True)
-
- @timed
- def delete_edge(
- self, source_id: str, target_id: str, type: str, user_name: str | None = None
- ) -> None:
- """
- Delete a specific edge between two nodes.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type to remove.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (a@Memory) -[r@{type}]-> (b@Memory)
- WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
- """
-
- query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
- query += "\nDELETE r"
- self.execute_query(query)
-
- @timed
- def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory)
- WHERE n.memory_type = "{memory_type}"
- """
- query += f"\nAND n.user_name = '{user_name}'"
- query += "\nRETURN COUNT(n) AS count"
-
- try:
- result = self.execute_query(query)
- return result.one_or_none()["count"].value
- except Exception as e:
- logger.error(f"[get_memory_count] Failed: {e}")
- return -1
-
- @timed
- def count_nodes(self, scope: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory)
- WHERE n.memory_type = "{scope}"
- """
- query += f"\nAND n.user_name = '{user_name}'"
- query += "\nRETURN count(n) AS count"
-
- result = self.execute_query(query)
- return result.one_or_none()["count"].value
-
- @timed
- def edge_exists(
- self,
- source_id: str,
- target_id: str,
- type: str = "ANY",
- direction: str = "OUTGOING",
- user_name: str | None = None,
- ) -> bool:
- """
- Check if an edge exists between two nodes.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type. Use "ANY" to match any relationship type.
- direction: Direction of the edge.
- Use "OUTGOING" (default), "INCOMING", or "ANY".
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- True if the edge exists, otherwise False.
- """
- # Prepare the relationship pattern
- user_name = user_name if user_name else self.config.user_name
- rel = "r" if type == "ANY" else f"r@{type}"
-
- # Prepare the match pattern with direction
- if direction == "OUTGOING":
- pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
- elif direction == "INCOMING":
- pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
- elif direction == "ANY":
- pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
- else:
- raise ValueError(
- f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
- )
- query = f"MATCH {pattern}"
- query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
- query += "\nRETURN r"
-
- # Run the Cypher query
- result = self.execute_query(query)
- record = result.one_or_none()
- if record is None:
- return False
- return record.values() is not None
-
- @timed
- # Graph Query & Reasoning
- def get_node(
- self, id: str, include_embedding: bool = False, user_name: str | None = None
- ) -> dict[str, Any] | None:
- """
- Retrieve a Memory node by its unique ID.
-
- Args:
- id (str): Node ID (Memory.id)
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- dict: Node properties as key-value pairs, or None if not found.
- """
- filter_clause = f'n.id = "{id}"'
- return_fields = self._build_return_fields(include_embedding)
- gql = f"""
- MATCH (n@Memory)
- WHERE {filter_clause}
- RETURN {return_fields}
- """
-
- try:
- result = self.execute_query(gql)
- for row in result:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- return node
-
- except Exception as e:
- logger.error(
- f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
- )
- return None
-
- @timed
- def get_nodes(
- self,
- ids: list[str],
- include_embedding: bool = False,
- user_name: str | None = None,
- **kwargs,
- ) -> list[dict[str, Any]]:
- """
- Retrieve the metadata and memory of a list of nodes.
- Args:
- ids: List of Node identifier.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
-
- Notes:
- - Assumes all provided IDs are valid and exist.
- - Returns empty list if input is empty.
- """
- if not ids:
- return []
- # Safe formatting of the ID list
- id_list = ",".join(f'"{_id}"' for _id in ids)
-
- return_fields = self._build_return_fields(include_embedding)
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE n.id IN [{id_list}]
- RETURN {return_fields}
- """
- nodes = []
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- nodes.append(self._parse_node(props))
- except Exception as e:
- logger.error(
- f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
- )
- return nodes
-
- @timed
- def get_edges(
- self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
- ) -> list[dict[str, str]]:
- """
- Get edges connected to a node, with optional type and direction filter.
-
- Args:
- id: Node ID to retrieve edges for.
- type: Relationship type to match, or 'ANY' to match all.
- direction: 'OUTGOING', 'INCOMING', or 'ANY'.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- List of edges:
- [
- {"from": "source_id", "to": "target_id", "type": "RELATE"},
- ...
- ]
- """
- # Build relationship type filter
- rel_type = "" if type == "ANY" else f"@{type}"
- user_name = user_name if user_name else self.config.user_name
- # Build Cypher pattern based on direction
- if direction == "OUTGOING":
- pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "INCOMING":
- pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "ANY":
- pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
- where_clause = f"a.id = '{id}' OR b.id = '{id}'"
- else:
- raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
-
- where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
-
- query = f"""
- MATCH {pattern}
- WHERE {where_clause}
- RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
- """
-
- result = self.execute_query(query)
- edges = []
- for record in result:
- edges.append(
- {
- "from": record["from_id"].value,
- "to": record["to_id"].value,
- "type": record["edge_type"].value,
- }
- )
- return edges
-
- @timed
- def get_neighbors_by_tag(
- self,
- tags: list[str],
- exclude_ids: list[str],
- top_k: int = 5,
- min_overlap: int = 1,
- include_embedding: bool = False,
- user_name: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Find top-K neighbor nodes with maximum tag overlap.
-
- Args:
- tags: The list of tags to match.
- exclude_ids: Node IDs to exclude (e.g., local cluster).
- top_k: Max number of neighbors to return.
- min_overlap: Minimum number of overlapping tags required.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- List of dicts with node details and overlap count.
- """
- if not tags:
- return []
- user_name = user_name if user_name else self.config.user_name
- where_clauses = [
- 'n.status = "activated"',
- 'NOT (n.node_type = "reasoning")',
- 'NOT (n.memory_type = "WorkingMemory")',
- ]
- if exclude_ids:
- where_clauses.append(f"NOT (n.id IN {exclude_ids})")
-
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- where_clause = " AND ".join(where_clauses)
- tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
-
- return_fields = self._build_return_fields(include_embedding)
- query = f"""
- LET tag_list = {tag_list_literal}
-
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {where_clause}
- RETURN {return_fields},
- size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
- ORDER BY overlap_count DESC
- LIMIT {top_k}
- """
-
- result = self.execute_query(query)
- neighbors: list[dict[str, Any]] = []
- for r in result:
- props = {k: v.value for k, v in r.items() if k != "overlap_count"}
- parsed = self._parse_node(props)
- parsed["overlap_count"] = r["overlap_count"].value
- neighbors.append(parsed)
-
- neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
- neighbors = neighbors[:top_k]
- result = []
- for neighbor in neighbors[:top_k]:
- neighbor.pop("overlap_count")
- result.append(neighbor)
- return result
-
- @timed
- def get_children_with_embeddings(
- self, id: str, user_name: str | None = None
- ) -> list[dict[str, Any]]:
- user_name = user_name if user_name else self.config.user_name
- where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'"
-
- query = f"""
- MATCH (p@Memory)-[@PARENT]->(c@Memory)
- WHERE p.id = "{id}" {where_user}
- RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
- """
- result = self.execute_query(query)
- children = []
- for row in result:
- eid = row["id"].value # STRING
- emb_v = row[self.dim_field].value # NVector
- emb = list(emb_v.values) if emb_v else []
- mem = row["memory"].value # STRING
-
- children.append({"id": eid, "embedding": emb, "memory": mem})
- return children
-
- @timed
- def get_subgraph(
- self,
- center_id: str,
- depth: int = 2,
- center_status: str = "activated",
- user_name: str | None = None,
- ) -> dict[str, Any]:
- """
- Retrieve a local subgraph centered at a given node.
- Args:
- center_id: The ID of the center node.
- depth: The hop distance for neighbors.
- center_status: Required status for center node.
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- {
- "core_node": {...},
- "neighbors": [...],
- "edges": [...]
- }
- """
- if not 1 <= depth <= 5:
- raise ValueError("depth must be 1-5")
-
- user_name = user_name if user_name else self.config.user_name
-
- gql = f"""
- MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE center.id = '{center_id}'
- AND center.status = '{center_status}'
- AND center.user_name = '{user_name}'
- OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory)
- WHERE neighbor.user_name = '{user_name}'
- RETURN center,
- collect(DISTINCT neighbor) AS neighbors,
- collect(EDGES(p)) AS edge_chains
- """
-
- result = self.execute_query(gql).one_or_none()
- if not result or result.size == 0:
- return {"core_node": None, "neighbors": [], "edges": []}
-
- core_node_props = result["center"].as_node().get_properties()
- core_node = self._parse_node(core_node_props)
- neighbors = []
- vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]}
- for n in result["neighbors"].value:
- n_node = n.as_node()
- n_props = n_node.get_properties()
- node_parsed = self._parse_node(n_props)
- neighbors.append(node_parsed)
- vid_to_id_map[n_node.node_id] = node_parsed["id"]
-
- edges = []
- for chain_group in result["edge_chains"].value:
- for edge_wr in chain_group.value:
- edge = edge_wr.value
- edges.append(
- {
- "type": edge.get_type(),
- "source": vid_to_id_map.get(edge.get_src_id()),
- "target": vid_to_id_map.get(edge.get_dst_id()),
- }
- )
-
- return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
-
- @timed
- # Search / recall operations
- def search_by_embedding(
- self,
- vector: list[float],
- top_k: int = 5,
- scope: str | None = None,
- status: str | None = None,
- threshold: float | None = None,
- search_filter: dict | None = None,
- user_name: str | None = None,
- **kwargs,
- ) -> list[dict]:
- """
- Retrieve node IDs based on vector similarity.
-
- Args:
- vector (list[float]): The embedding vector representing query semantics.
- top_k (int): Number of top similar nodes to retrieve.
- scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
- status (str, optional): Node status filter (e.g., 'active', 'archived').
- If provided, restricts results to nodes with matching status.
- threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
- search_filter (dict, optional): Additional metadata filters for search results.
- Keys should match node properties, values are the expected values.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
-
- Notes:
- - This method uses Neo4j native vector indexing to search for similar nodes.
- - If scope is provided, it restricts results to nodes with matching memory_type.
- - If 'status' is provided, only nodes with the matching status will be returned.
- - If threshold is provided, only results with score >= threshold will be returned.
- - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- - Typical use case: restrict to 'status = activated' to avoid
- matching archived or merged nodes.
- """
- user_name = user_name if user_name else self.config.user_name
- vector = _normalize(vector)
- dim = len(vector)
- vector_str = ",".join(f"{float(x)}" for x in vector)
- gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
- where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
- if scope:
- where_clauses.append(f'n.memory_type = "{scope}"')
- if status:
- where_clauses.append(f'n.status = "{status}"')
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- # Add search_filter conditions
- if search_filter:
- for key, value in search_filter.items():
- if isinstance(value, str):
- where_clauses.append(f'n.{key} = "{value}"')
- else:
- where_clauses.append(f"n.{key} = {value}")
-
- where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
-
- gql = f"""
- let a = {gql_vector}
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- ORDER BY inner_product(n.{self.dim_field}, a) DESC
- LIMIT {top_k}
- RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
- try:
- result = self.execute_query(gql)
- except Exception as e:
- logger.error(f"[search_by_embedding] Query failed: {e}")
- return []
-
- try:
- output = []
- for row in result:
- values = row.values()
- id_val = values[0].as_string()
- score_val = values[1].as_double()
- score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
- if threshold is None or score_val >= threshold:
- output.append({"id": id_val, "score": score_val})
- return output
- except Exception as e:
- logger.error(f"[search_by_embedding] Result parse failed: {e}")
- return []
-
- @timed
- def get_by_metadata(
- self, filters: list[dict[str, Any]], user_name: str | None = None
- ) -> list[str]:
- """
- 1. ADD logic: "AND" vs "OR"(support logic combination);
- 2. Support nested conditional expressions;
-
- Retrieve node IDs that match given metadata filters.
- Supports exact match.
-
- Args:
- filters: List of filter dicts like:
- [
- {"field": "key", "op": "in", "value": ["A", "B"]},
- {"field": "confidence", "op": ">=", "value": 80},
- {"field": "tags", "op": "contains", "value": "AI"},
- ...
- ]
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
-
- Notes:
- - Supports structured querying such as tag/category/importance/time filtering.
- - Can be used for faceted recall or prefiltering before embedding rerank.
- """
- where_clauses = []
- user_name = user_name if user_name else self.config.user_name
- for _i, f in enumerate(filters):
- field = f["field"]
- op = f.get("op", "=")
- value = f["value"]
-
- escaped_value = self._format_value(value)
-
- # Build WHERE clause
- if op == "=":
- where_clauses.append(f"n.{field} = {escaped_value}")
- elif op == "in":
- where_clauses.append(f"n.{field} IN {escaped_value}")
- elif op == "contains":
- where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
- elif op == "starts_with":
- where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
- elif op == "ends_with":
- where_clauses.append(f"n.{field} ENDS WITH {escaped_value}")
- elif op in [">", ">=", "<", "<="]:
- where_clauses.append(f"n.{field} {op} {escaped_value}")
- else:
- raise ValueError(f"Unsupported operator: {op}")
-
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- where_str = " AND ".join(where_clauses)
- gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id"
- ids = []
- try:
- result = self.execute_query(gql)
- ids = [record["id"].value for record in result]
- except Exception as e:
- logger.error(f"Failed to get metadata: {e}, gql is {gql}")
- return ids
-
- @timed
- def get_grouped_counts(
- self,
- group_fields: list[str],
- where_clause: str = "",
- params: dict[str, Any] | None = None,
- user_name: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Count nodes grouped by any fields.
-
- Args:
- group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
- where_clause (str, optional): Extra WHERE condition. E.g.,
- "WHERE n.status = 'activated'"
- params (dict, optional): Parameters for WHERE clause.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
- """
- if not group_fields:
- raise ValueError("group_fields cannot be empty")
- user_name = user_name if user_name else self.config.user_name
- # GQL-specific modifications
- user_clause = f"n.user_name = '{user_name}'"
- if where_clause:
- where_clause = where_clause.strip()
- if where_clause.upper().startswith("WHERE"):
- where_clause += f" AND {user_clause}"
- else:
- where_clause = f"WHERE {where_clause} AND {user_clause}"
- else:
- where_clause = f"WHERE {user_clause}"
-
- # Inline parameters if provided
- if params:
- for key, value in params.items():
- # Handle different value types appropriately
- if isinstance(value, str):
- value = f"'{value}'"
- where_clause = where_clause.replace(f"${key}", str(value))
-
- return_fields = []
- group_by_fields = []
-
- for field in group_fields:
- alias = field.replace(".", "_")
- return_fields.append(f"n.{field} AS {alias}")
- group_by_fields.append(alias)
- # Full GQL query construction
- gql = f"""
- MATCH (n /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- RETURN {", ".join(return_fields)}, COUNT(n) AS count
- """
- result = self.execute_query(gql) # Pure GQL string execution
-
- output = []
- for record in result:
- group_values = {}
- for i, field in enumerate(group_fields):
- value = record.values()[i].as_string()
- group_values[field] = value
- count_value = record["count"].value
- output.append({**group_values, "count": count_value})
-
- return output
-
- @timed
- def clear(self, user_name: str | None = None) -> None:
- """
- Clear the entire graph if the target database exists.
-
- Args:
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- try:
- query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n"
- self.execute_query(query)
- logger.info("Cleared all nodes from database.")
-
- except Exception as e:
- logger.error(f"[ERROR] Failed to clear database: {e}")
-
- @timed
- def export_graph(
- self, include_embedding: bool = False, user_name: str | None = None, **kwargs
- ) -> dict[str, Any]:
- """
- Export all graph nodes and edges in a structured form.
- Args:
- include_embedding (bool): Whether to include the large embedding field.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- {
- "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
- "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
- }
- """
- user_name = user_name if user_name else self.config.user_name
- node_query = "MATCH (n@Memory)"
- edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
- node_query += f' WHERE n.user_name = "{user_name}"'
- edge_query += f' WHERE r.user_name = "{user_name}"'
-
- try:
- if include_embedding:
- return_fields = "n"
- else:
- return_fields = ",".join(
- [
- "n.id AS id",
- "n.memory AS memory",
- "n.user_name AS user_name",
- "n.user_id AS user_id",
- "n.session_id AS session_id",
- "n.status AS status",
- "n.key AS key",
- "n.confidence AS confidence",
- "n.tags AS tags",
- "n.created_at AS created_at",
- "n.updated_at AS updated_at",
- "n.memory_type AS memory_type",
- "n.sources AS sources",
- "n.source AS source",
- "n.node_type AS node_type",
- "n.visibility AS visibility",
- "n.usage AS usage",
- "n.background AS background",
- ]
- )
-
- full_node_query = f"{node_query} RETURN {return_fields}"
- node_result = self.execute_query(full_node_query, timeout=20)
- nodes = []
- logger.debug(f"Debugging: {node_result}")
- for row in node_result:
- if include_embedding:
- props = row.values()[0].as_node().get_properties()
- else:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- nodes.append(node)
- except Exception as e:
- raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
-
- try:
- full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
- edge_result = self.execute_query(full_edge_query, timeout=20)
- edges = [
- {
- "source": row.values()[0].value,
- "target": row.values()[1].value,
- "type": row.values()[2].value,
- }
- for row in edge_result
- ]
- except Exception as e:
- raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
-
- return {"nodes": nodes, "edges": edges}
-
- @timed
- def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
- """
- Import the entire graph from a serialized dictionary.
-
- Args:
- data: A dictionary containing all nodes and edges to be loaded.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- for node in data.get("nodes", []):
- try:
- id, memory, metadata = _compose_node(node)
- metadata["user_name"] = user_name
- metadata = self._prepare_node_metadata(metadata)
- metadata.update({"id": id, "memory": memory})
- properties = ", ".join(
- f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
- )
- node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
- self.execute_query(node_gql)
- except Exception as e:
- logger.error(f"Fail to load node: {node}, error: {e}")
-
- for edge in data.get("edges", []):
- try:
- source_id, target_id = edge["source"], edge["target"]
- edge_type = edge["type"]
- props = f'{{user_name: "{user_name}"}}'
- edge_gql = f'''
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
- INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
- '''
- self.execute_query(edge_gql)
- except Exception as e:
- logger.error(f"Fail to load edge: {edge}, error: {e}")
-
- @timed
- def get_all_memory_items(
- self, scope: str, include_embedding: bool = False, user_name: str | None = None
- ) -> (list)[dict]:
- """
- Retrieve all memory items of a specific memory_type.
-
- Args:
- scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: Full list of memory items under this scope.
- """
- user_name = user_name if user_name else self.config.user_name
- if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
- raise ValueError(f"Unsupported memory type scope: {scope}")
-
- where_clause = f"WHERE n.memory_type = '{scope}'"
- where_clause += f" AND n.user_name = '{user_name}'"
-
- return_fields = self._build_return_fields(include_embedding)
-
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- RETURN {return_fields}
- LIMIT 100
- """
- nodes = []
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- nodes.append(self._parse_node(props))
- except Exception as e:
- logger.error(f"Failed to get memories: {e}")
- return nodes
-
- @timed
- def get_structure_optimization_candidates(
- self, scope: str, include_embedding: bool = False, user_name: str | None = None
- ) -> list[dict]:
- """
- Find nodes that are likely candidates for structure optimization:
- - Isolated nodes, nodes with empty background, or nodes with exactly one child.
- - Plus: the child of any parent node that has exactly one child.
- """
- user_name = user_name if user_name else self.config.user_name
- where_clause = f'''
- n.memory_type = "{scope}"
- AND n.status = "activated"
- '''
- where_clause += f' AND n.user_name = "{user_name}"'
-
- return_fields = self._build_return_fields(include_embedding)
- return_fields += f", n.{self.dim_field} AS {self.dim_field}"
-
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {where_clause}
- OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
- OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
- WHERE c IS NULL AND p IS NULL
- RETURN {return_fields}
- """
-
- candidates = []
- node_ids = set()
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- node_id = node["id"]
- if node_id not in node_ids:
- candidates.append(node)
- node_ids.add(node_id)
- except Exception as e:
- logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
- return candidates
-
- @timed
- def drop_database(self) -> None:
- """
- Permanently delete the entire database this instance is using.
- WARNING: This operation is destructive and cannot be undone.
- """
- raise ValueError(
- f"Refusing to drop protected database: `{self.db_name}` in "
- f"Shared Database Multi-Tenant mode"
- )
-
- @timed
- def detect_conflicts(self) -> list[tuple[str, str]]:
- """
- Detect conflicting nodes based on logical or semantic inconsistency.
- Returns:
- A list of (node_id1, node_id2) tuples that conflict.
- """
- raise NotImplementedError
-
- @timed
- # Structure Maintenance
- def deduplicate_nodes(self) -> None:
- """
- Deduplicate redundant or semantically similar nodes.
- This typically involves identifying nodes with identical or near-identical memory.
- """
- raise NotImplementedError
-
- @timed
- def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
- """
- Get the ordered context chain starting from a node, following a relationship type.
- Args:
- id: Starting node ID.
- type: Relationship type to follow (e.g., 'FOLLOWS').
- Returns:
- List of ordered node IDs in the chain.
- """
- raise NotImplementedError
-
- @timed
- def get_neighbors(
- self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
- ) -> list[str]:
- """
- Get connected node IDs in a specific direction and relationship type.
- Args:
- id: Source node ID.
- type: Relationship type.
- direction: Edge direction to follow ('out', 'in', or 'both').
- Returns:
- List of neighboring node IDs.
- """
- raise NotImplementedError
-
- @timed
- def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
- """
- Get the path of nodes from source to target within a limited depth.
- Args:
- source_id: Starting node ID.
- target_id: Target node ID.
- max_depth: Maximum path length to traverse.
- Returns:
- Ordered list of node IDs along the path.
- """
- raise NotImplementedError
-
- @timed
- def merge_nodes(self, id1: str, id2: str) -> str:
- """
- Merge two similar or duplicate nodes into one.
- Args:
- id1: First node ID.
- id2: Second node ID.
- Returns:
- ID of the resulting merged node.
- """
- raise NotImplementedError
-
- @classmethod
- def _ensure_space_exists(cls, tmp_client, cfg):
- """Lightweight check to ensure target graph (space) exists."""
- db_name = getattr(cfg, "space", None)
- if not db_name:
- logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
- return
-
- try:
- res = tmp_client.execute("SHOW GRAPHS")
- existing = {row.values()[0].as_string() for row in res}
- if db_name not in existing:
- tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
- logger.info(f"✅ Graph `{db_name}` created before session binding.")
- else:
- logger.debug(f"Graph `{db_name}` already exists.")
- except Exception:
- logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
-
- @timed
- def _ensure_database_exists(self):
- graph_type_name = "MemOSBgeM3Type"
-
- check_type_query = "SHOW GRAPH TYPES"
- result = self.execute_query(check_type_query, auto_set_db=False)
-
- type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
-
- if not type_exists:
- create_tag = f"""
- CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
- NODE Memory (:MemoryTag {{
- id STRING,
- memory STRING,
- user_name STRING,
- user_id STRING,
- session_id STRING,
- status STRING,
- key STRING,
- confidence FLOAT,
- tags LIST,
- created_at STRING,
- updated_at STRING,
- memory_type STRING,
- sources LIST,
- source STRING,
- node_type STRING,
- visibility STRING,
- usage LIST,
- background STRING,
- {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
- PRIMARY KEY(id)
- }}),
- EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
- }}
- """
- self.execute_query(create_tag, auto_set_db=False)
- else:
- describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}"
- desc_result = self.execute_query(describe_query, auto_set_db=False)
-
- memory_fields = []
- for row in desc_result:
- field_name = row.values()[0].as_string()
- memory_fields.append(field_name)
-
- if self.dim_field not in memory_fields:
- alter_query = f"""
- ALTER GRAPH TYPE {graph_type_name} {{
- ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
- }}
- """
- self.execute_query(alter_query, auto_set_db=False)
- logger.info(f"✅ Add new vector search {self.dim_field} to {graph_type_name}")
- else:
- logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
-
- create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
- try:
- self.execute_query(create_graph, auto_set_db=False)
- logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
- except Exception as e:
- logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
-
- @timed
- def _create_vector_index(
- self,
- label: str = "Memory",
- vector_property: str = "embedding",
- dimensions: int = 3072,
- index_name: str = "memory_vector_index",
- ) -> None:
- """
- Create a vector index for the specified property in the label.
- """
- if str(dimensions) == str(self.default_memory_dimension):
- index_name = f"idx_{vector_property}"
- vector_name = vector_property
- else:
- index_name = f"idx_{vector_property}_{dimensions}"
- vector_name = f"{vector_property}_{dimensions}"
-
- create_vector_index = f"""
- CREATE VECTOR INDEX IF NOT EXISTS {index_name}
- ON NODE {label}::{vector_name}
- OPTIONS {{
- DIM: {dimensions},
- METRIC: IP,
- TYPE: IVF,
- NLIST: 100,
- TRAINSIZE: 1000
- }}
- FOR `{self.db_name}`
- """
- self.execute_query(create_vector_index)
- logger.info(
- f"✅ Ensure {label}::{vector_property} vector index {index_name} "
- f"exists (DIM={dimensions})"
- )
-
- @timed
- def _create_basic_property_indexes(self) -> None:
- """
- Create standard B-tree indexes on status, memory_type, created_at
- and updated_at fields.
- Create standard B-tree indexes on user_name when use Shared Database
- Multi-Tenant Mode.
- """
- fields = [
- "status",
- "memory_type",
- "created_at",
- "updated_at",
- "user_name",
- ]
-
- for field in fields:
- index_name = f"idx_memory_{field}"
- gql = f"""
- CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
- FOR `{self.db_name}`
- """
- try:
- self.execute_query(gql)
- logger.info(f"✅ Created index: {index_name} on field {field}")
- except Exception as e:
- logger.error(
- f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
- )
-
- @timed
- def _index_exists(self, index_name: str) -> bool:
- """
- Check if an index with the given name exists.
- """
- """
- Check if a vector index with the given name exists in NebulaGraph.
-
- Args:
- index_name (str): The name of the index to check.
-
- Returns:
- bool: True if the index exists, False otherwise.
- """
- query = "SHOW VECTOR INDEXES"
- try:
- result = self.execute_query(query)
- return any(row.values()[0].as_string() == index_name for row in result)
- except Exception as e:
- logger.error(f"[Nebula] Failed to check index existence: {e}")
- return False
-
- @timed
- def _parse_value(self, value: Any) -> Any:
- """turn Nebula ValueWrapper to Python type"""
- from nebulagraph_python.value_wrapper import ValueWrapper
-
- if value is None or (hasattr(value, "is_null") and value.is_null()):
- return None
- try:
- prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
- except Exception as e:
- logger.warning(f"Error when decode Nebula ValueWrapper: {e}")
- prim = value.cast() if isinstance(value, ValueWrapper) else value
-
- if isinstance(prim, ValueWrapper):
- return self._parse_value(prim)
- if isinstance(prim, list):
- return [self._parse_value(v) for v in prim]
- if type(prim).__name__ == "NVector":
- return list(prim.values)
-
- return prim # already a Python primitive
-
- def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
- parsed = {k: self._parse_value(v) for k, v in props.items()}
-
- for tf in ("created_at", "updated_at"):
- if tf in parsed and parsed[tf] is not None:
- parsed[tf] = _normalize_datetime(parsed[tf])
-
- node_id = parsed.pop("id")
- memory = parsed.pop("memory", "")
- parsed.pop("user_name", None)
- metadata = parsed
- metadata["type"] = metadata.pop("node_type")
-
- if self.dim_field in metadata:
- metadata["embedding"] = metadata.pop(self.dim_field)
-
- return {"id": node_id, "memory": memory, "metadata": metadata}
-
- @timed
- def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """
- Ensure metadata has proper datetime fields and normalized types.
-
- - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
- - Convert embedding to list of float if present.
- """
- now = datetime.utcnow().isoformat()
- metadata["node_type"] = metadata.pop("type")
-
- # Fill timestamps if missing
- metadata.setdefault("created_at", now)
- metadata.setdefault("updated_at", now)
-
- # Normalize embedding type
- embedding = metadata.get("embedding")
- if embedding and isinstance(embedding, list):
- metadata.pop("embedding")
- metadata[self.dim_field] = _normalize([float(x) for x in embedding])
-
- return metadata
-
- @timed
- def _format_value(self, val: Any, key: str = "") -> str:
- from nebulagraph_python.py_data_types import NVector
-
- # None
- if val is None:
- return "NULL"
- # bool
- if isinstance(val, bool):
- return "true" if val else "false"
- # str
- if isinstance(val, str):
- return f'"{_escape_str(val)}"'
- # num
- elif isinstance(val, (int | float)):
- return str(val)
- # time
- elif isinstance(val, datetime):
- return f'datetime("{val.isoformat()}")'
- # list
- elif isinstance(val, list):
- if key == self.dim_field:
- dim = len(val)
- joined = ",".join(str(float(x)) for x in val)
- return f"VECTOR<{dim}, FLOAT>([{joined}])"
- else:
- return f"[{', '.join(self._format_value(v) for v in val)}]"
- # NVector
- elif isinstance(val, NVector):
- if key == self.dim_field:
- dim = len(val)
- joined = ",".join(str(float(x)) for x in val)
- return f"VECTOR<{dim}, FLOAT>([{joined}])"
- else:
- logger.warning("Invalid NVector")
- # dict
- if isinstance(val, dict):
- j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
- return f'"{_escape_str(j)}"'
- else:
- return f'"{_escape_str(str(val))}"'
-
- @timed
- def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """
- Filter and validate metadata dictionary against the Memory node schema.
- - Removes keys not in schema.
- - Warns if required fields are missing.
- """
-
- dim_fields = {self.dim_field}
-
- allowed_fields = self.common_fields | dim_fields
-
- missing_fields = allowed_fields - metadata.keys()
- if missing_fields:
- logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
-
- filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
-
- return filtered_metadata
-
- def _build_return_fields(self, include_embedding: bool = False) -> str:
- fields = set(self.base_fields)
- if include_embedding:
- fields.add(self.dim_field)
- return ", ".join(f"n.{f} AS {f}" for f in fields)
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index 33eb39692..56c3e08a0 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -39,7 +39,7 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
metadata["embedding"] = [float(x) for x in embedding]
# serialization
- if metadata["sources"]:
+ if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
return metadata
@@ -72,8 +72,35 @@ def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]:
return metadata
+def _sanitize_neo4j_value(value: Any) -> Any:
+ """Convert values unsupported by Neo4j properties into safe serializations."""
+ if value is None or isinstance(value, str | int | float | bool):
+ return value
+
+ if isinstance(value, list):
+ if all(item is None or isinstance(item, str | int | float | bool) for item in value):
+ return value
+ return [
+ json.dumps(item, ensure_ascii=False) if isinstance(item, dict | list) else str(item)
+ for item in value
+ ]
+
+ if isinstance(value, dict):
+ return json.dumps(value, ensure_ascii=False, sort_keys=True)
+
+ return str(value)
+
+
+def _sanitize_neo4j_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+ """Ensure all metadata values are valid Neo4j property types."""
+ return {key: _sanitize_neo4j_value(value) for key, value in metadata.items()}
+
+
class Neo4jGraphDB(BaseGraphDB):
- """Neo4j-based implementation of a graph memory store."""
+ """Neo4j-based implementation of a graph memory store.
+
+ Requires Neo4j >= 5.18 for vector.similarity.cosine() pre-filtering support.
+ """
@require_python_package(
import_name="neo4j",
@@ -209,6 +236,9 @@ def add_node(
# Flatten info fields to top level (for Neo4j flat structure)
metadata = _flatten_info_fields(metadata)
+ # Ensure Neo4j property compatibility (no nested map/list-of-map values)
+ metadata = _sanitize_neo4j_metadata(metadata)
+
# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")
@@ -226,7 +256,7 @@ def add_node(
"""
# serialization
- if metadata["sources"]:
+ if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
@@ -843,13 +873,14 @@ def search_by_embedding(
If return_fields is specified, each dict also includes the requested fields.
Notes:
- - This method uses Neo4j native vector indexing to search for similar nodes.
- - If scope is provided, it restricts results to nodes with matching memory_type.
- - If 'status' is provided, only nodes with the matching status will be returned.
+ - When filters are present (scope, status, user_name, etc.), this method uses
+ Neo4j 5.18+ pre-filtering: MATCH + WHERE narrows candidates first, then
+ vector.similarity.cosine() computes similarity only on the filtered set.
+ This avoids the post-filter problem where queryNodes' global top-k excludes
+ the target user's nodes in a multi-tenant shared database.
+ - When no filters are present, the ANN vector index (db.index.vector.queryNodes)
+ is used for maximum efficiency.
- If threshold is provided, only results with score >= threshold will be returned.
- - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- - Typical use case: restrict to 'status = activated' to avoid
- matching archived or merged nodes.
"""
user_name = user_name if user_name else self.config.user_name
# Build WHERE clause dynamically
@@ -901,14 +932,28 @@ def search_by_embedding(
if extra_fields:
return_clause = f"RETURN node.id AS id, score, {extra_fields}"
- query = f"""
- CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding)
- YIELD node, score
- {where_clause}
- {return_clause}
- """
-
- parameters = {"embedding": vector, "k": top_k}
+ if where_clause:
+ # Pre-filtering (Neo4j 5.18+): filter nodes first, then compute similarity.
+ # This avoids the post-filter problem where relevant nodes are excluded
+ # from the global top-k returned by queryNodes.
+ where_clause += " AND node.embedding IS NOT NULL"
+ query = f"""
+ MATCH (node:Memory)
+ {where_clause}
+ WITH node, vector.similarity.cosine(node.embedding, $embedding) AS score
+ {return_clause}
+ ORDER BY score DESC
+ LIMIT $top_k
+ """
+ parameters = {"embedding": vector, "top_k": top_k}
+ else:
+ # No filter: use ANN vector index for efficiency.
+ query = f"""
+ CALL db.index.vector.queryNodes('memory_vector_index', $top_k, $embedding)
+ YIELD node, score
+ {return_clause}
+ """
+ parameters = {"embedding": vector, "top_k": top_k}
if scope:
parameters["scope"] = scope
@@ -1842,7 +1887,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
- and node["sources"][idx][0] == "}"
+ and node["sources"][idx][-1] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])
diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py
index 283e15115..b5c92f40a 100644
--- a/src/memos/graph_dbs/neo4j_community.py
+++ b/src/memos/graph_dbs/neo4j_community.py
@@ -5,7 +5,12 @@
from typing import Any
from memos.configs.graph_db import Neo4jGraphDBConfig
-from memos.graph_dbs.neo4j import Neo4jGraphDB, _flatten_info_fields, _prepare_node_metadata
+from memos.graph_dbs.neo4j import (
+ Neo4jGraphDB,
+ _flatten_info_fields,
+ _prepare_node_metadata,
+ _sanitize_neo4j_metadata,
+)
from memos.log import get_logger
from memos.vec_dbs.factory import VecDBFactory
from memos.vec_dbs.item import VecDBItem
@@ -55,40 +60,43 @@ def add_node(
# Safely process metadata
metadata = _prepare_node_metadata(metadata)
+ metadata = _flatten_info_fields(metadata)
+ metadata = _sanitize_neo4j_metadata(metadata)
# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")
# serialization
- if metadata["sources"]:
+ if metadata.get("sources"):
for idx in range(len(metadata["sources"])):
metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
# Extract required fields
embedding = metadata.pop("embedding", None)
- if embedding is None:
- raise ValueError(f"Missing 'embedding' in metadata for node {id}")
# Merge node and set metadata
created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")
- vector_sync_status = "success"
+ vector_sync_status = "skipped"
- try:
- # Write to Vector DB
- item = VecDBItem(
- id=id,
- vector=embedding,
- payload={
- "memory": memory,
- "vector_sync": vector_sync_status,
- **metadata, # unpack all metadata keys to top-level
- },
- )
- self.vec_db.add([item])
- except Exception as e:
- logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
- vector_sync_status = "failed"
+ if embedding is not None:
+ vector_sync_status = "success"
+ try:
+ item = VecDBItem(
+ id=id,
+ vector=embedding,
+ payload={
+ "memory": memory,
+ "vector_sync": vector_sync_status,
+ **metadata,
+ },
+ )
+ self.vec_db.add([item])
+ except Exception as e:
+ logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
+ vector_sync_status = "failed"
+ else:
+ logger.warning(f"[add_node] No embedding for node {id}, skipping vector DB insert")
metadata["vector_sync"] = vector_sync_status
query = """
@@ -134,6 +142,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N
metadata = _prepare_node_metadata(metadata)
metadata = _flatten_info_fields(metadata)
+ metadata = _sanitize_neo4j_metadata(metadata)
# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
@@ -141,18 +150,21 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N
embedding = metadata.pop("embedding", None)
- vector_sync_status = "success"
- vec_items.append(
- VecDBItem(
- id=node_id,
- vector=embedding,
- payload={
- "memory": memory,
- "vector_sync": vector_sync_status,
- **metadata,
- },
+ if embedding is not None:
+ vector_sync_status = "success"
+ vec_items.append(
+ VecDBItem(
+ id=node_id,
+ vector=embedding,
+ payload={
+ "memory": memory,
+ "vector_sync": vector_sync_status,
+ **metadata,
+ },
+ )
)
- )
+ else:
+ vector_sync_status = "skipped"
created_at = metadata.pop("created_at")
updated_at = metadata.pop("updated_at")
@@ -889,6 +901,14 @@ def build_filter_condition(
if condition_str:
where_clauses.append(f"({condition_str})")
filter_params.update(filter_params_inner)
+ else:
+ # Simple dict syntax: {"user_id": "...", "session_id": "..."}
+ condition_str, filter_params_inner = build_filter_condition(
+ filter, param_counter
+ )
+ if condition_str:
+ where_clauses.append(f"({condition_str})")
+ filter_params.update(filter_params_inner)
where_str = " AND ".join(where_clauses) if where_clauses else ""
if where_str:
@@ -919,7 +939,7 @@ def build_filter_condition(
def delete_node_by_prams(
self,
- writable_cube_ids: list[str],
+ writable_cube_ids: list[str] | None = None,
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
@@ -928,7 +948,7 @@ def delete_node_by_prams(
Delete nodes by memory_ids, file_ids, or filter.
Args:
- writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
+ writable_cube_ids (list[str], optional): List of cube IDs (user_name) to scope deletion.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.
@@ -943,9 +963,9 @@ def delete_node_by_prams(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)
- # Validate writable_cube_ids
- if not writable_cube_ids or len(writable_cube_ids) == 0:
- raise ValueError("writable_cube_ids is required and cannot be empty")
+ # file_ids deletion must be scoped by writable_cube_ids.
+ if file_ids and (not writable_cube_ids or len(writable_cube_ids) == 0):
+ raise ValueError("writable_cube_ids is required when deleting by file_ids")
# Build WHERE conditions separately for memory_ids and file_ids
where_clauses = []
@@ -953,10 +973,11 @@ def delete_node_by_prams(
# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
user_name_conditions = []
- for idx, cube_id in enumerate(writable_cube_ids):
- param_name = f"cube_id_{idx}"
- user_name_conditions.append(f"n.user_name = ${param_name}")
- params[param_name] = cube_id
+ if writable_cube_ids:
+ for idx, cube_id in enumerate(writable_cube_ids):
+ param_name = f"cube_id_{idx}"
+ user_name_conditions.append(f"n.user_name = ${param_name}")
+ params[param_name] = cube_id
# Handle memory_ids: query n.id
if memory_ids and len(memory_ids) > 0:
@@ -1003,9 +1024,12 @@ def delete_node_by_prams(
# First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
- # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
- user_name_where = " OR ".join(user_name_conditions)
- ids_where = f"({user_name_where}) AND ({data_conditions})"
+ # Then, combine with user_name condition using AND when scope is provided.
+ if user_name_conditions:
+ user_name_where = " OR ".join(user_name_conditions)
+ ids_where = f"({user_name_where}) AND ({data_conditions})"
+ else:
+ ids_where = data_conditions
logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
@@ -1138,12 +1162,12 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
node[time_field] = node[time_field].isoformat()
node.pop("user_name", None)
# serialization
- if node["sources"]:
+ if node.get("sources"):
for idx in range(len(node["sources"])):
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
- and node["sources"][idx][0] == "}"
+ and node["sources"][idx][-1] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])
@@ -1179,7 +1203,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]
if not (
isinstance(node["sources"][idx], str)
and node["sources"][idx][0] == "{"
- and node["sources"][idx][0] == "}"
+ and node["sources"][idx][-1] == "}"
):
break
node["sources"][idx] = json.loads(node["sources"][idx])
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 4d88844df..abfae7710 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -165,7 +165,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
# Create connection pool
self.connection_pool = psycopg2.pool.ThreadedConnectionPool(
- minconn=5,
+ minconn=1,
maxconn=maxconn,
host=host,
port=port,
@@ -176,6 +176,7 @@ def __init__(self, config: PolarDBGraphDBConfig):
keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout)
keepalives_interval=15, # Seconds between keepalive retries
keepalives_count=5, # Number of keepalive retries before considering connection dead
+ keepalives=1,
options=f"-c search_path={self.db_name}_graph,ag_catalog,$user,public",
)
@@ -250,39 +251,40 @@ def _get_connection(self):
import psycopg2
timeout = self._connection_wait_timeout
- if not self._semaphore.acquire(timeout=max(timeout, 0)):
+ if timeout is None or timeout <= 0:
+ self._semaphore.acquire()
+ elif not self._semaphore.acquire(timeout=timeout):
logger.warning(f"Timeout waiting for connection slot ({timeout}s)")
raise RuntimeError("Connection pool busy")
- logger.info(
- "Connection pool usage: %s/%s",
- self.connection_pool.maxconn - self._semaphore._value,
- self.connection_pool.maxconn,
- )
+
conn = None
broken = False
try:
- conn = self.connection_pool.getconn()
- conn.autocommit = True
for attempt in range(2):
+ conn = self.connection_pool.getconn()
+ conn.autocommit = True
try:
with conn.cursor() as cur:
cur.execute("SELECT 1")
break
except psycopg2.Error:
- logger.warning("Dead connection detected, recreating (attempt %d)", attempt + 1)
+ logger.warning(f"Dead connection detected, recreating (attempt {attempt + 1})")
self.connection_pool.putconn(conn, close=True)
- conn = self.connection_pool.getconn()
- conn.autocommit = True
+ conn = None
else:
raise RuntimeError("Cannot obtain valid DB connection after 2 attempts")
with conn.cursor() as cur:
cur.execute(f'SET search_path = {self.db_name}_graph, ag_catalog, "$user", public;')
yield conn
- except Exception:
+ except (psycopg2.Error, psycopg2.OperationalError) as e:
broken = True
+ logger.exception(f"Database connection busy : {e}")
+ raise
+ except Exception as e:
+ logger.exception(f"Unexpected error: {e}")
raise
finally:
- if conn:
+ if conn is not None:
try:
self.connection_pool.putconn(conn, close=broken)
logger.debug(f"Returned connection {id(conn)} to pool (broken={broken})")
@@ -441,7 +443,7 @@ def remove_oldest_memory(
)
user_name = user_name if user_name else self._get_config_value("user_name")
- # Use actual OFFSET logic, consistent with nebular.py
+ # Use actual OFFSET logic for deterministic pruning
# First find IDs to delete, then delete them
select_query = f"""
SELECT id FROM "{self.db_name}_graph"."Memory"
@@ -1798,7 +1800,6 @@ def search_by_fulltext(
FROM "{self.db_name}_graph"."Memory" m
CROSS JOIN q
{where_clause_cte}
- ORDER BY rank DESC
LIMIT {top_k};
"""
params = [tsquery_string]
@@ -2194,7 +2195,7 @@ def get_grouped_counts(
SELECT {", ".join(cte_select_list)}
FROM "{self.db_name}_graph"."Memory"
{where_clause}
- LIMIT 1000
+ LIMIT 100
)
SELECT {outer_select}, count(*) AS count
FROM t
@@ -2409,56 +2410,47 @@ def _extract_special_filter_values(filter_obj):
order_clause = """
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,id DESC
"""
+ count_query = f"""
+ SELECT COUNT(*) AS total_count
+ FROM "{self.db_name}_graph"."Memory"
+ {where_clause}
+ """
if include_embedding:
- node_query = f"""
- WITH filtered AS (
- SELECT id, properties, embedding
- FROM "{self.db_name}_graph"."Memory"
- {where_clause}
- )
- SELECT p.id, p.properties, p.embedding, c.total_count
- FROM (SELECT COUNT(*) AS total_count FROM filtered) c
- LEFT JOIN LATERAL (
- SELECT id, properties, embedding
- FROM filtered
- {order_clause}
- {pagination_clause}
- ) p ON TRUE
+ data_query = f"""
+ SELECT id, properties, embedding
+ FROM "{self.db_name}_graph"."Memory"
+ {where_clause}
+ {order_clause}
+ {pagination_clause}
"""
else:
- node_query = f"""
- WITH filtered AS (
- SELECT id, properties
- FROM "{self.db_name}_graph"."Memory"
- {where_clause}
- )
- SELECT p.id, p.properties, c.total_count
- FROM (SELECT COUNT(*) AS total_count FROM filtered) c
- LEFT JOIN LATERAL (
- SELECT id, properties
- FROM filtered
- {order_clause}
- {pagination_clause}
- ) p ON TRUE
+ data_query = f"""
+ SELECT id, properties
+ FROM "{self.db_name}_graph"."Memory"
+ {where_clause}
+ {order_clause}
+ {pagination_clause}
"""
- logger.info(f"[export_graph nodes] Query: {node_query}")
+ logger.info(f"[export_graph nodes] count_query: {count_query}")
+ logger.info(f"[export_graph nodes] data_query: {data_query}")
try:
with self._get_connection() as conn, conn.cursor() as cursor:
- cursor.execute(node_query)
+ cursor.execute(count_query)
+ count_row = cursor.fetchone()
+ total_nodes = int(count_row[0]) if count_row and count_row[0] is not None else 0
+
+ cursor.execute(data_query)
node_results = cursor.fetchall()
nodes = []
for row in node_results:
if include_embedding:
- row_id, properties_json, embedding_json, row_total_count = row
+ row_id, properties_json, embedding_json = row
else:
- row_id, properties_json, row_total_count = row
+ row_id, properties_json = row
embedding_json = None
- if row_total_count is not None:
- total_nodes = int(row_total_count)
-
if row_id is None:
continue
@@ -3539,7 +3531,7 @@ def get_neighbors_by_tag_ccl(
user_name = user_name if user_name else self._get_config_value("user_name")
- # Build query conditions; keep consistent with nebular.py
+ # Build query conditions shared with other graph backends
where_clauses = [
'n.status = "activated"',
'NOT (n.node_type = "reasoning")',
@@ -3592,7 +3584,7 @@ def get_neighbors_by_tag_ccl(
# Add overlap_count
result_fields.append("overlap_count agtype")
result_fields_str = ", ".join(result_fields)
- # Use Cypher query; keep consistent with nebular.py
+ # Use Cypher query to keep the graph query path aligned
query = f"""
SELECT * FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
diff --git a/src/memos/graph_dbs/postgres.py b/src/memos/graph_dbs/postgres.py
index 1c1cae378..594f7e695 100644
--- a/src/memos/graph_dbs/postgres.py
+++ b/src/memos/graph_dbs/postgres.py
@@ -10,6 +10,7 @@
"""
import json
+import re
import time
from contextlib import suppress
@@ -438,6 +439,202 @@ def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]:
result["metadata"]["embedding"] = row[5]
return result
+ @staticmethod
+ def _is_safe_field_name(field: str) -> bool:
+ """Validate field names used in dynamic SQL fragments."""
+        return bool(re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", field))
+
+ def _field_expr(self, key: str) -> tuple[str, str]:
+ """
+ Build text/json SQL expressions for a filter key.
+
+ Returns:
+ tuple[text_expr, json_expr]
+ """
+ direct_columns = {"id", "memory", "user_name", "created_at", "updated_at"}
+ if key in direct_columns:
+ return key, key
+
+ if key.startswith("info."):
+ sub_key = key[5:]
+ if not self._is_safe_field_name(sub_key):
+ raise ValueError(f"Invalid filter field: {key}")
+ return f"properties->'info'->>'{sub_key}'", f"properties->'info'->'{sub_key}'"
+
+ if not self._is_safe_field_name(key):
+ raise ValueError(f"Invalid filter field: {key}")
+ return f"properties->>'{key}'", f"properties->'{key}'"
+
+ def _build_single_filter_condition(
+ self, condition_dict: dict[str, Any], params: list[Any]
+ ) -> str | None:
+ """Build SQL for a single filter condition dict."""
+ if not condition_dict:
+ return None
+
+ array_fields = {"tags", "sources", "file_ids"}
+ timestamp_fields = {"created_at", "updated_at"}
+ parts: list[str] = []
+
+ for key, value in condition_dict.items():
+ text_expr, json_expr = self._field_expr(key)
+ raw_key = key[5:] if key.startswith("info.") else key
+
+ if isinstance(value, dict):
+ for op, op_value in value.items():
+ if op in ("gt", "lt", "gte", "lte"):
+ op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="}
+ sql_op = op_map[op]
+ if raw_key in timestamp_fields or raw_key.endswith("_at"):
+ parts.append(f"({text_expr})::timestamptz {sql_op} %s::timestamptz")
+ params.append(op_value)
+ else:
+ parts.append(f"NULLIF({text_expr}, '')::numeric {sql_op} %s")
+ params.append(op_value)
+ elif op == "contains":
+ if raw_key in array_fields:
+ parts.append(f"{json_expr} @> %s::jsonb")
+ params.append(json.dumps([op_value]))
+ else:
+ parts.append(f"{text_expr} ILIKE %s")
+ params.append(f"%{op_value}%")
+ elif op == "in":
+ if not isinstance(op_value, list):
+ raise ValueError(
+ f"in operator expects list for '{key}', got {type(op_value).__name__}"
+ )
+ if raw_key in array_fields:
+ parts.append(f"{json_expr} ?| %s")
+ params.append([str(v) for v in op_value])
+ else:
+ parts.append(f"{text_expr} = ANY(%s)")
+ params.append([str(v) for v in op_value])
+ elif op == "like":
+ parts.append(f"{text_expr} ILIKE %s")
+ params.append(f"%{op_value}%")
+ else:
+ raise ValueError(f"Unsupported filter operator: {op}")
+ else:
+ if raw_key in array_fields:
+ if isinstance(value, list):
+ parts.append(f"{json_expr} @> %s::jsonb")
+ params.append(json.dumps(value))
+ else:
+ parts.append(f"{json_expr} @> %s::jsonb")
+ params.append(json.dumps([value]))
+ else:
+ parts.append(f"{text_expr} = %s")
+ params.append(str(value))
+
+ if not parts:
+ return None
+ return " AND ".join(parts)
+
+ def _build_filter_where_clause(self, filter_dict: dict[str, Any], params: list[Any]) -> str:
+ """Build SQL WHERE fragment from filter dict."""
+ if not filter_dict:
+ return ""
+
+ if "and" in filter_dict:
+ and_conditions = filter_dict.get("and")
+ if not isinstance(and_conditions, list):
+ raise ValueError("Invalid filter format: 'and' must be a list")
+ parts: list[str] = []
+ for cond in and_conditions:
+ if isinstance(cond, dict):
+ cond_sql = self._build_single_filter_condition(cond, params)
+ if cond_sql:
+ parts.append(f"({cond_sql})")
+ return " AND ".join(parts)
+
+ if "or" in filter_dict:
+ or_conditions = filter_dict.get("or")
+ if not isinstance(or_conditions, list):
+ raise ValueError("Invalid filter format: 'or' must be a list")
+ parts: list[str] = []
+ for cond in or_conditions:
+ if isinstance(cond, dict):
+ cond_sql = self._build_single_filter_condition(cond, params)
+ if cond_sql:
+ parts.append(f"({cond_sql})")
+ return f"({' OR '.join(parts)})" if parts else ""
+
+ cond_sql = self._build_single_filter_condition(filter_dict, params)
+ return cond_sql or ""
+
+ def delete_node_by_prams(
+ self,
+ writable_cube_ids: list[str] | None = None,
+ memory_ids: list[str] | None = None,
+ file_ids: list[str] | None = None,
+ filter: dict | None = None,
+ ) -> int:
+        """Delete nodes matching memory_ids, file_ids, or filter, scoped to writable_cube_ids."""
+ logger.info(
+ "[delete_node_by_prams] memory_ids: %s, file_ids: %s, filter: %s, writable_cube_ids: %s",
+ memory_ids,
+ file_ids,
+ filter,
+ writable_cube_ids,
+ )
+
+ where_conditions: list[str] = []
+ params: list[Any] = []
+
+ if memory_ids:
+ where_conditions.append("id = ANY(%s)")
+ params.append(memory_ids)
+
+ if file_ids:
+ file_conditions: list[str] = []
+ for file_id in file_ids:
+ file_conditions.append("properties->'file_ids' @> %s::jsonb")
+ params.append(json.dumps([file_id]))
+ if file_conditions:
+ where_conditions.append(f"({' OR '.join(file_conditions)})")
+
+ if filter:
+ filter_where = self._build_filter_where_clause(filter, params)
+ if filter_where:
+ where_conditions.append(f"({filter_where})")
+
+ if not where_conditions:
+ logger.warning(
+ "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
+ )
+ return 0
+
+ if writable_cube_ids:
+ where_conditions.append("user_name = ANY(%s)")
+ params.append(writable_cube_ids)
+
+ where_clause = " AND ".join(where_conditions)
+
+ conn = self._get_conn()
+ try:
+ with conn.cursor() as cur:
+ query = f"""
+ WITH to_delete AS (
+ SELECT id
+ FROM {self.schema}.memories
+ WHERE {where_clause}
+ ),
+ deleted_edges AS (
+ DELETE FROM {self.schema}.edges e
+ USING to_delete d
+ WHERE e.source_id = d.id OR e.target_id = d.id
+ )
+ DELETE FROM {self.schema}.memories m
+ USING to_delete d
+ WHERE m.id = d.id
+ """
+ cur.execute(query, params)
+ deleted_count = cur.rowcount if cur.rowcount is not None else 0
+ logger.info("[delete_node_by_prams] Deleted %s nodes", deleted_count)
+ return deleted_count
+ finally:
+ self._put_conn(conn)
+
# =========================================================================
# Edge Management
# =========================================================================
diff --git a/src/memos/llms/factory.py b/src/memos/llms/factory.py
index 8f4da662f..20b27593e 100644
--- a/src/memos/llms/factory.py
+++ b/src/memos/llms/factory.py
@@ -5,6 +5,7 @@
from memos.llms.deepseek import DeepSeekLLM
from memos.llms.hf import HFLLM
from memos.llms.hf_singleton import HFSingletonLLM
+from memos.llms.minimax import MinimaxLLM
from memos.llms.ollama import OllamaLLM
from memos.llms.openai import AzureLLM, OpenAILLM
from memos.llms.openai_new import OpenAIResponsesLLM
@@ -25,6 +26,7 @@ class LLMFactory(BaseLLM):
"vllm": VLLMLLM,
"qwen": QwenLLM,
"deepseek": DeepSeekLLM,
+ "minimax": MinimaxLLM,
"openai_new": OpenAIResponsesLLM,
}
diff --git a/src/memos/llms/minimax.py b/src/memos/llms/minimax.py
new file mode 100644
index 000000000..3bee9882f
--- /dev/null
+++ b/src/memos/llms/minimax.py
@@ -0,0 +1,13 @@
+from memos.configs.llm import MinimaxLLMConfig
+from memos.llms.openai import OpenAILLM
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class MinimaxLLM(OpenAILLM):
+ """MiniMax LLM class via OpenAI-compatible API."""
+
+ def __init__(self, config: MinimaxLLMConfig):
+ super().__init__(config)
diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py
index 135058a7d..a57a40676 100644
--- a/src/memos/mem_feedback/feedback.py
+++ b/src/memos/mem_feedback/feedback.py
@@ -2,6 +2,7 @@
import difflib
import json
import re
+import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal
@@ -236,6 +237,7 @@ def _single_add_operation(
else:
to_add_memory = new_memory_item.model_copy(deep=True)
+ to_add_memory.id = str(uuid.uuid4())
if to_add_memory.metadata.memory_type == "PreferenceMemory":
to_add_memory.metadata.preference = new_memory_item.memory
@@ -359,9 +361,14 @@ def semantics_feedback(
lang = detect_lang("".join(memory_item.memory))
template = FEEDBACK_PROMPT_DICT["compare"][lang]
if current_memories == []:
- # retrieve
- last_user_index = max(i for i, d in enumerate(chat_history_list) if d["role"] == "user")
- last_qa = " ".join([item["content"] for item in chat_history_list[last_user_index:]])
+ user_indices = [i for i, d in enumerate(chat_history_list) if d["role"] == "user"]
+ if user_indices:
+ last_user_index = max(user_indices)
+ last_qa = " ".join(
+ [item["content"] for item in chat_history_list[last_user_index:]]
+ )
+ else:
+ last_qa = " ".join([item["content"] for item in chat_history_list])
supplementary_retrieved = self._retrieve(last_qa, info=info, user_name=user_name)
feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name)
diff --git a/src/memos/mem_os/client.py b/src/memos/mem_os/client.py
deleted file mode 100644
index f4a591e59..000000000
--- a/src/memos/mem_os/client.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# TODO: @Li Ji
-
-
-class ClientMOS:
- pass
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
deleted file mode 100644
index b2c74c384..000000000
--- a/src/memos/mem_os/product.py
+++ /dev/null
@@ -1,1610 +0,0 @@
-import asyncio
-import json
-import os
-import random
-import time
-
-from collections.abc import Generator
-from datetime import datetime
-from typing import Any, Literal
-
-from dotenv import load_dotenv
-from transformers import AutoTokenizer
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.context.context import ContextThread
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.core import MOSCore
-from memos.mem_os.utils.format_utils import (
- clean_json_response,
- convert_graph_to_tree_forworkmem,
- ensure_unique_tree_ids,
- filter_nodes_by_tree_ids,
- remove_embedding_recursive,
- sort_children_by_memory_type,
-)
-from memos.mem_os.utils.reference_utils import (
- prepare_reference_data,
- process_streaming_references_complete,
-)
-from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
-from memos.mem_scheduler.schemas.task_schemas import (
- ANSWER_TASK_LABEL,
- QUERY_TASK_LABEL,
-)
-from memos.mem_user.persistent_factory import PersistentUserManagerFactory
-from memos.mem_user.user_manager import UserRole
-from memos.memories.textual.item import (
- TextualMemoryItem,
-)
-from memos.templates.mos_prompts import (
- FURTHER_SUGGESTION_PROMPT,
- SUGGESTION_QUERY_PROMPT_EN,
- SUGGESTION_QUERY_PROMPT_ZH,
- get_memos_prompt,
-)
-from memos.types import MessageList
-from memos.utils import timed
-
-
-logger = get_logger(__name__)
-
-load_dotenv()
-
-CUBE_PATH = os.getenv("MOS_CUBE_PATH", "/tmp/data/")
-
-
-def _short_id(mem_id: str) -> str:
- return (mem_id or "").split("-")[0] if mem_id else ""
-
-
-def _format_mem_block(memories_all, max_items: int = 20, max_chars_each: int = 320) -> str:
- """
- Modify TextualMemoryItem Format:
- 1:abcd :: [P] text...
- 2:ef01 :: [O] text...
- sequence is [i:memId] i; [P]=PersonalMemory / [O]=OuterMemory
- """
- if not memories_all:
- return "(none)", "(none)"
-
- lines_o = []
- lines_p = []
- for idx, m in enumerate(memories_all[:max_items], 1):
- mid = _short_id(getattr(m, "id", "") or "")
- mtype = getattr(getattr(m, "metadata", {}), "memory_type", None) or getattr(
- m, "metadata", {}
- ).get("memory_type", "")
- tag = "O" if "Outer" in str(mtype) else "P"
- txt = (getattr(m, "memory", "") or "").replace("\n", " ").strip()
- if len(txt) > max_chars_each:
- txt = txt[: max_chars_each - 1] + "โฆ"
- mid = mid or f"mem_{idx}"
- if tag == "O":
- lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}\n")
- elif tag == "P":
- lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}")
- return "\n".join(lines_o), "\n".join(lines_p)
-
-
-class MOSProduct(MOSCore):
- """
- The MOSProduct class inherits from MOSCore and manages multiple users.
- Each user has their own configuration and cube access, but shares the same model instances.
- """
-
- def __init__(
- self,
- default_config: MOSConfig | None = None,
- max_user_instances: int = 1,
- default_cube_config: GeneralMemCubeConfig | None = None,
- online_bot=None,
- error_bot=None,
- ):
- """
- Initialize MOSProduct with an optional default configuration.
-
- Args:
- default_config (MOSConfig | None): Default configuration for new users
- max_user_instances (int): Maximum number of user instances to keep in memory
- default_cube_config (GeneralMemCubeConfig | None): Default cube configuration for loading cubes
- online_bot: DingDing online_bot function or None if disabled
- error_bot: DingDing error_bot function or None if disabled
- """
- # Initialize with a root config for shared resources
- if default_config is None:
- # Create a minimal config for root user
- root_config = MOSConfig(
- user_id="root",
- session_id="root_session",
- chat_model=default_config.chat_model if default_config else None,
- mem_reader=default_config.mem_reader if default_config else None,
- enable_mem_scheduler=default_config.enable_mem_scheduler
- if default_config
- else False,
- mem_scheduler=default_config.mem_scheduler if default_config else None,
- )
- else:
- root_config = default_config.model_copy(deep=True)
- root_config.user_id = "root"
- root_config.session_id = "root_session"
-
- # Create persistent user manager BEFORE calling parent constructor
- persistent_user_manager_client = PersistentUserManagerFactory.from_config(
- config_factory=root_config.user_manager
- )
-
- # Initialize parent MOSCore with root config and persistent user manager
- super().__init__(root_config, user_manager=persistent_user_manager_client)
-
- # Product-specific attributes
- self.default_config = default_config
- self.default_cube_config = default_cube_config
- self.max_user_instances = max_user_instances
- self.online_bot = online_bot
- self.error_bot = error_bot
-
- # User-specific data structures
- self.user_configs: dict[str, MOSConfig] = {}
- self.user_cube_access: dict[str, set[str]] = {} # user_id -> set of cube_ids
- self.user_chat_histories: dict[str, dict] = {}
-
- # Note: self.user_manager is now the persistent user manager from parent class
- # No need for separate global_user_manager as they are the same instance
-
- # Initialize tiktoken for streaming
- try:
- # Use gpt2 encoding which is more stable and widely compatible
- self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
- logger.info("tokenizer initialized successfully for streaming")
- except Exception as e:
- logger.warning(
- f"Failed to initialize tokenizer, will use character-based chunking: {e}"
- )
- self.tokenizer = None
-
- # Restore user instances from persistent storage
- self._restore_user_instances(default_cube_config=default_cube_config)
- logger.info(f"User instances restored successfully, now user is {self.mem_cubes.keys()}")
-
- def _restore_user_instances(
- self, default_cube_config: GeneralMemCubeConfig | None = None
- ) -> None:
- """Restore user instances from persistent storage after service restart.
-
- Args:
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
- """
- try:
- # Get all user configurations from persistent storage
- user_configs = self.user_manager.list_user_configs(self.max_user_instances)
-
- # Get the raw database records for sorting by updated_at
- session = self.user_manager._get_session()
- try:
- from memos.mem_user.persistent_user_manager import UserConfig
-
- db_configs = session.query(UserConfig).limit(self.max_user_instances).all()
- # Create a mapping of user_id to updated_at timestamp
- updated_at_map = {config.user_id: config.updated_at for config in db_configs}
-
- # Sort by updated_at timestamp (most recent first) and limit by max_instances
- sorted_configs = sorted(
- user_configs.items(), key=lambda x: updated_at_map.get(x[0], ""), reverse=True
- )[: self.max_user_instances]
- finally:
- session.close()
-
- for user_id, config in sorted_configs:
- if user_id != "root": # Skip root user
- try:
- # Store user config and cube access
- self.user_configs[user_id] = config
- self._load_user_cube_access(user_id)
-
- # Pre-load all cubes for this user with default config
- self._preload_user_cubes(user_id, default_cube_config)
-
- logger.info(
- f"Restored user configuration and pre-loaded cubes for {user_id}"
- )
-
- except Exception as e:
- logger.error(f"Failed to restore user configuration for {user_id}: {e}")
-
- except Exception as e:
- logger.error(f"Error during user instance restoration: {e}")
-
- def _initialize_cube_from_default_config(
- self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig
- ) -> GeneralMemCube | None:
- """
- Initialize a cube from default configuration when cube path doesn't exist.
-
- Args:
- cube_id (str): The cube ID to initialize.
- user_id (str): The user ID for the cube.
- default_config (GeneralMemCubeConfig): The default configuration to use.
- """
- cube_config = default_config.model_copy(deep=True)
- # Safely modify the graph_db user_name if it exists
- if cube_config.text_mem.config.graph_db.config:
- cube_config.text_mem.config.graph_db.config.user_name = (
- f"memos{user_id.replace('-', '')}"
- )
- mem_cube = GeneralMemCube(config=cube_config)
- return mem_cube
-
- def _preload_user_cubes(
- self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None
- ) -> None:
- """Pre-load all cubes for a user into memory.
-
- Args:
- user_id (str): The user ID to pre-load cubes for.
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
- """
- try:
- # Get user's accessible cubes from persistent storage
- accessible_cubes = self.user_manager.get_user_cubes(user_id)
-
- for cube in accessible_cubes:
- if cube.cube_id not in self.mem_cubes:
- try:
- if cube.cube_path and os.path.exists(cube.cube_path):
- # Pre-load cube with all memory types and default config
- self.register_mem_cube(
- cube.cube_path,
- cube.cube_id,
- user_id,
- memory_types=["act_mem"]
- if self.config.enable_activation_memory
- else [],
- default_config=default_cube_config,
- )
- logger.info(f"Pre-loaded cube {cube.cube_id} for user {user_id}")
- else:
- logger.warning(
- f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, skipping pre-load"
- )
- except Exception as e:
- logger.error(
- f"Failed to pre-load cube {cube.cube_id} for user {user_id}: {e}",
- exc_info=True,
- )
-
- except Exception as e:
- logger.error(f"Error pre-loading cubes for user {user_id}: {e}", exc_info=True)
-
- @timed
- def _load_user_cubes(
- self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None
- ) -> None:
- """Load all cubes for a user into memory.
-
- Args:
- user_id (str): The user ID to load cubes for.
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
- """
- # Get user's accessible cubes from persistent storage
- accessible_cubes = self.user_manager.get_user_cubes(user_id)
-
- for cube in accessible_cubes[:1]:
- if cube.cube_id not in self.mem_cubes:
- try:
- if cube.cube_path and os.path.exists(cube.cube_path):
- # Use MOSCore's register_mem_cube method directly with default config
- # Only load act_mem since text_mem is stored in database
- self.register_mem_cube(
- cube.cube_path,
- cube.cube_id,
- user_id,
- memory_types=["act_mem"],
- default_config=default_cube_config,
- )
- else:
- logger.warning(
- f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, now init by default config"
- )
- cube_obj = self._initialize_cube_from_default_config(
- cube_id=cube.cube_id,
- user_id=user_id,
- default_config=default_cube_config,
- )
- if cube_obj:
- self.register_mem_cube(
- cube_obj,
- cube.cube_id,
- user_id,
- memory_types=[],
- )
- else:
- raise ValueError(
- f"Failed to initialize default cube {cube.cube_id} for user {user_id}"
- )
- except Exception as e:
- logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}")
- logger.info(f"load user {user_id} cubes successfully")
-
- def _ensure_user_instance(self, user_id: str, max_instances: int | None = None) -> None:
- """
- Ensure user configuration exists, creating it if necessary.
-
- Args:
- user_id (str): The user ID
- max_instances (int): Maximum instances to keep in memory (overrides class default)
- """
- if user_id in self.user_configs:
- return
-
- # Try to get config from persistent storage first
- stored_config = self.user_manager.get_user_config(user_id)
- if stored_config:
- self.user_configs[user_id] = stored_config
- self._load_user_cube_access(user_id)
- else:
- # Use default config
- if not self.default_config:
- raise ValueError(f"No configuration available for user {user_id}")
- user_config = self.default_config.model_copy(deep=True)
- user_config.user_id = user_id
- user_config.session_id = f"{user_id}_session"
- self.user_configs[user_id] = user_config
- self._load_user_cube_access(user_id)
-
- # Apply LRU eviction if needed
- max_instances = max_instances or self.max_user_instances
- if len(self.user_configs) > max_instances:
- # Remove least recently used instance (excluding root)
- user_ids = [uid for uid in self.user_configs if uid != "root"]
- if user_ids:
- oldest_user_id = user_ids[0]
- del self.user_configs[oldest_user_id]
- if oldest_user_id in self.user_cube_access:
- del self.user_cube_access[oldest_user_id]
- logger.info(f"Removed least recently used user configuration: {oldest_user_id}")
-
- def _load_user_cube_access(self, user_id: str) -> None:
- """Load user's cube access permissions."""
- try:
- # Get user's accessible cubes from persistent storage
- accessible_cubes = self.user_manager.get_user_cube_access(user_id)
- self.user_cube_access[user_id] = set(accessible_cubes)
- except Exception as e:
- logger.warning(f"Failed to load cube access for user {user_id}: {e}")
- self.user_cube_access[user_id] = set()
-
- def _get_user_config(self, user_id: str) -> MOSConfig:
- """Get user configuration."""
- if user_id not in self.user_configs:
- self._ensure_user_instance(user_id)
- return self.user_configs[user_id]
-
- def _validate_user_cube_access(self, user_id: str, cube_id: str) -> None:
- """Validate user has access to the cube."""
- if user_id not in self.user_cube_access:
- self._load_user_cube_access(user_id)
-
- if cube_id not in self.user_cube_access.get(user_id, set()):
- raise ValueError(f"User '{user_id}' does not have access to cube '{cube_id}'")
-
- def _validate_user_access(self, user_id: str, cube_id: str | None = None) -> None:
- """Validate user access using MOSCore's built-in validation."""
- # Use MOSCore's built-in user validation
- if cube_id:
- self._validate_cube_access(user_id, cube_id)
- else:
- self._validate_user_exists(user_id)
-
- def _create_user_config(self, user_id: str, config: MOSConfig) -> MOSConfig:
- """Create a new user configuration."""
- # Create a copy of config with the specific user_id
- user_config = config.model_copy(deep=True)
- user_config.user_id = user_id
- user_config.session_id = f"{user_id}_session"
-
- # Save configuration to persistent storage
- self.user_manager.save_user_config(user_id, user_config)
-
- return user_config
-
- def _get_or_create_user_config(
- self, user_id: str, config: MOSConfig | None = None
- ) -> MOSConfig:
- """Get existing user config or create a new one."""
- if user_id in self.user_configs:
- return self.user_configs[user_id]
-
- # Try to get config from persistent storage first
- stored_config = self.user_manager.get_user_config(user_id)
- if stored_config:
- return self._create_user_config(user_id, stored_config)
-
- # Use provided config or default config
- user_config = config or self.default_config
- if not user_config:
- raise ValueError(f"No configuration provided for user {user_id}")
-
- return self._create_user_config(user_id, user_config)
-
- def _build_system_prompt(
- self,
- memories_all: list[TextualMemoryItem],
- base_prompt: str | None = None,
- tone: str = "friendly",
- verbosity: str = "mid",
- ) -> str:
- """
- Build custom system prompt for the user with memory references.
-
- Args:
- user_id (str): The user ID.
- memories (list[TextualMemoryItem]): The memories to build the system prompt.
-
- Returns:
- str: The custom system prompt.
- """
- # Build base prompt
- # Add memory context if available
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(
- date=formatted_date, tone=tone, verbosity=verbosity, mode="base"
- )
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
- mem_block = mem_block_o + "\n" + mem_block_p
- prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
- return (
- prefix
- + sys_body
- + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n"
- + mem_block
- )
-
- def _build_base_system_prompt(
- self,
- base_prompt: str | None = None,
- tone: str = "friendly",
- verbosity: str = "mid",
- mode: str = "enhance",
- ) -> str:
- """
- Build base system prompt without memory references.
- """
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode)
- prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
- return prefix + sys_body
-
- def _build_memory_context(
- self,
- memories_all: list[TextualMemoryItem],
- mode: str = "enhance",
- ) -> str:
- """
- Build memory context to be included in user message.
- """
- if not memories_all:
- return ""
-
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
-
- if mode == "enhance":
- return (
- "# Memories\n## PersonalMemory (ordered)\n"
- + mem_block_p
- + "\n## OuterMemory (ordered)\n"
- + mem_block_o
- + "\n\n"
- )
- else:
- mem_block = mem_block_o + "\n" + mem_block_p
- return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n"
-
- def _build_enhance_system_prompt(
- self,
- user_id: str,
- memories_all: list[TextualMemoryItem],
- tone: str = "friendly",
- verbosity: str = "mid",
- ) -> str:
- """
- Build enhance prompt for the user with memory references.
- [DEPRECATED] Use _build_base_system_prompt and _build_memory_context instead.
- """
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(
- date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance"
- )
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
- return (
- sys_body
- + "\n\n# Memories\n## PersonalMemory (ordered)\n"
- + mem_block_p
- + "\n## OuterMemory (ordered)\n"
- + mem_block_o
- )
-
- def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]:
- """
- Extract reference information from the response and return clean text.
-
- Args:
- response (str): The complete response text.
-
- Returns:
- tuple[str, list[dict]]: A tuple containing:
- - clean_text: Text with reference markers removed
- - references: List of reference information
- """
- import re
-
- try:
- references = []
- # Pattern to match [refid:memoriesID]
- pattern = r"\[(\d+):([^\]]+)\]"
-
- matches = re.findall(pattern, response)
- for ref_number, memory_id in matches:
- references.append({"memory_id": memory_id, "reference_number": int(ref_number)})
-
- # Remove all reference markers from the text to get clean text
- clean_text = re.sub(pattern, "", response)
-
- # Clean up any extra whitespace that might be left after removing markers
- clean_text = re.sub(r"\s+", " ", clean_text).strip()
-
- return clean_text, references
- except Exception as e:
- logger.error(f"Error extracting references from response: {e}", exc_info=True)
- return response, []
-
- def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict:
- """
- get struct message from chat-history
- # TODO: @xcy make this more general
- """
- system_content = ""
- memory_content = ""
- chat_history = []
-
- for item in chat_data:
- role = item.get("role")
- content = item.get("content", "")
- if role == "system":
- parts = content.split("# Memories", 1)
- system_content = parts[0].strip()
- if len(parts) > 1:
- memory_content = "# Memories" + parts[1].strip()
- elif role in ("user", "assistant"):
- chat_history.append({"role": role, "content": content})
-
- if chat_history and chat_history[-1]["role"] == "assistant":
- if len(chat_history) >= 2 and chat_history[-2]["role"] == "user":
- chat_history = chat_history[:-2]
- else:
- chat_history = chat_history[:-1]
-
- return {"system": system_content, "memory": memory_content, "chat_history": chat_history}
-
- def _chunk_response_with_tiktoken(
- self, response: str, chunk_size: int = 5
- ) -> Generator[str, None, None]:
- """
- Chunk response using tiktoken for proper token-based streaming.
-
- Args:
- response (str): The response text to chunk.
- chunk_size (int): Number of tokens per chunk.
-
- Yields:
- str: Chunked text pieces.
- """
- if self.tokenizer:
- # Use tiktoken for proper token-based chunking
- tokens = self.tokenizer.encode(response)
-
- for i in range(0, len(tokens), chunk_size):
- token_chunk = tokens[i : i + chunk_size]
- chunk_text = self.tokenizer.decode(token_chunk)
- yield chunk_text
- else:
- # Fallback to character-based chunking
- char_chunk_size = chunk_size * 4 # Approximate character to token ratio
- for i in range(0, len(response), char_chunk_size):
- yield response[i : i + char_chunk_size]
-
- def _send_message_to_scheduler(
- self,
- user_id: str,
- mem_cube_id: str,
- query: str,
- label: str,
- ):
- """
- Send message to scheduler.
- args:
- user_id: str,
- mem_cube_id: str,
- query: str,
- """
-
- if self.enable_mem_scheduler and (self.mem_scheduler is not None):
- message_item = ScheduleMessageItem(
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- label=label,
- content=query,
- timestamp=datetime.utcnow(),
- )
- self.mem_scheduler.submit_messages(messages=[message_item])
-
- async def _post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions
- """
- try:
- logger.info(
- f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}"
- )
- logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
-
- clean_response, extracted_references = self._extract_references_from_response(
- full_response
- )
- struct_message = self._extract_struct_data_from_history(current_messages)
- logger.info(f"Extracted {len(extracted_references)} references from response")
-
- # Send chat report notifications asynchronously
- if self.online_bot:
- logger.info("Online Bot Open!")
- try:
- from memos.memos_tools.notification_utils import (
- send_online_bot_notification_async,
- )
-
- # Prepare notification data
- chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id}
- chat_data.update(
- {
- "memory": struct_message["memory"],
- "chat_history": struct_message["chat_history"],
- "full_response": full_response,
- }
- )
-
- system_data = {
- "references": extracted_references,
- "time_start": time_start,
- "time_end": time_end,
- "speed_improvement": speed_improvement,
- }
-
- emoji_config = {"chat": "๐ฌ", "system_info": "๐"}
-
- await send_online_bot_notification_async(
- online_bot=self.online_bot,
- header_name="MemOS Chat Report",
- sub_title_name="chat_with_references",
- title_color="#00956D",
- other_data1=chat_data,
- other_data2=system_data,
- emoji=emoji_config,
- )
- except Exception as e:
- logger.warning(f"Failed to send chat notification (async): {e}")
-
- self._send_message_to_scheduler(
- user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL
- )
-
- self.add(
- user_id=user_id,
- messages=[
- {
- "role": "user",
- "content": query,
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- {
- "role": "assistant",
- "content": clean_response, # Store clean text without reference markers
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- ],
- mem_cube_id=cube_id,
- )
-
- logger.info(f"Post-chat processing completed for user {user_id}")
-
- except Exception as e:
- logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True)
-
- def _start_post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments
- """
- logger.info("Start post_chat_processing...")
-
- def run_async_in_thread():
- """Running asynchronous tasks in a new thread"""
- try:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
- )
- finally:
- loop.close()
- except Exception as e:
- logger.error(
- f"Error in thread-based post-chat processing for user {user_id}: {e}",
- exc_info=True,
- )
-
- try:
- # Try to get the current event loop
- asyncio.get_running_loop()
- # Create task and store reference to prevent garbage collection
- task = asyncio.create_task(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
- )
- # Add exception handling for the background task
- task.add_done_callback(
- lambda t: (
- logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
- )
- if t.exception()
- else None
- )
- )
- except RuntimeError:
- # No event loop, run in a new thread with context propagation
- thread = ContextThread(
- target=run_async_in_thread,
- name=f"PostChatProcessing-{user_id}",
- # Set as a daemon thread to avoid blocking program exit
- daemon=True,
- )
- thread.start()
-
- def _filter_memories_by_threshold(
- self,
- memories: list[TextualMemoryItem],
- threshold: float = 0.30,
- min_num: int = 3,
- memory_type: Literal["OuterMemory"] = "OuterMemory",
- ) -> list[TextualMemoryItem]:
- """
- Filter memories by threshold and type, at least min_num memories for Non-OuterMemory.
- Args:
- memories: list[TextualMemoryItem],
- threshold: float,
- min_num: int,
- memory_type: Literal["OuterMemory"],
- Returns:
- list[TextualMemoryItem]
- """
- sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True)
- filtered_person = [m for m in memories if m.metadata.memory_type != memory_type]
- filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type]
- filtered = []
- per_memory_count = 0
- for m in sorted_memories:
- if m.metadata.relativity >= threshold:
- if m.metadata.memory_type != memory_type:
- per_memory_count += 1
- filtered.append(m)
- if len(filtered) < min_num:
- filtered = filtered_person[:min_num] + filtered_outer[:min_num]
- else:
- if per_memory_count < min_num:
- filtered += filtered_person[per_memory_count:min_num]
- filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True)
- return filtered_memory
-
- def register_mem_cube(
- self,
- mem_cube_name_or_path_or_object: str | GeneralMemCube,
- mem_cube_id: str | None = None,
- user_id: str | None = None,
- memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
- default_config: GeneralMemCubeConfig | None = None,
- ) -> None:
- """
- Register a MemCube with the MOS.
-
- Args:
- mem_cube_name_or_path_or_object (str | GeneralMemCube): The name, path, or GeneralMemCube object to register.
- mem_cube_id (str, optional): The identifier for the MemCube. If not provided, a default ID is used.
- user_id (str, optional): The user ID to register the cube for.
- memory_types (list[str], optional): List of memory types to load.
- If None, loads all available memory types.
- Options: ["text_mem", "act_mem", "para_mem"]
- default_config (GeneralMemCubeConfig, optional): Default configuration for the cube.
- """
- # Handle different input types
- if isinstance(mem_cube_name_or_path_or_object, GeneralMemCube):
- # Direct GeneralMemCube object provided
- mem_cube = mem_cube_name_or_path_or_object
- if mem_cube_id is None:
- mem_cube_id = f"cube_{id(mem_cube)}" # Generate a unique ID
- else:
- # String path provided
- mem_cube_name_or_path = mem_cube_name_or_path_or_object
- if mem_cube_id is None:
- mem_cube_id = mem_cube_name_or_path
-
- if mem_cube_id in self.mem_cubes:
- logger.info(f"MemCube with ID {mem_cube_id} already in MOS, skip install.")
- return
-
- # Create MemCube from path
- time_start = time.time()
- if os.path.exists(mem_cube_name_or_path):
- mem_cube = GeneralMemCube.init_from_dir(
- mem_cube_name_or_path, memory_types, default_config
- )
- logger.info(
- f"time register_mem_cube: init_from_dir time is: {time.time() - time_start}"
- )
- else:
- logger.warning(
- f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
- )
- mem_cube = GeneralMemCube.init_from_remote_repo(
- mem_cube_name_or_path, memory_types=memory_types, default_config=default_config
- )
-
- # Register the MemCube
- logger.info(
- f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}"
- )
- time_start = time.time()
- self.mem_cubes[mem_cube_id] = mem_cube
- time_end = time.time()
- logger.info(f"time register_mem_cube: add mem_cube time is: {time_end - time_start}")
-
- def user_register(
- self,
- user_id: str,
- user_name: str | None = None,
- config: MOSConfig | None = None,
- interests: str | None = None,
- default_mem_cube: GeneralMemCube | None = None,
- default_cube_config: GeneralMemCubeConfig | None = None,
- mem_cube_id: str | None = None,
- ) -> dict[str, str]:
- """Register a new user with configuration and default cube.
-
- Args:
- user_id (str): The user ID for registration.
- user_name (str): The user name for registration.
- config (MOSConfig | None, optional): User-specific configuration. Defaults to None.
- interests (str | None, optional): User interests as string. Defaults to None.
- default_mem_cube (GeneralMemCube | None, optional): Default memory cube. Defaults to None.
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
-
- Returns:
- dict[str, str]: Registration result with status and message.
- """
- try:
- # Use provided config or default config
- user_config = config or self.default_config
- if not user_config:
- return {
- "status": "error",
- "message": "No configuration provided for user registration",
- }
- if not user_name:
- user_name = user_id
-
- # Create user with configuration using persistent user manager
- self.user_manager.create_user_with_config(user_id, user_config, UserRole.USER, user_id)
-
- # Create user configuration
- user_config = self._create_user_config(user_id, user_config)
-
- # Create a default cube for the user using MOSCore's methods
- default_cube_name = f"{user_name}_{user_id}_default_cube"
- mem_cube_name_or_path = os.path.join(CUBE_PATH, default_cube_name)
- default_cube_id = self.create_cube_for_user(
- cube_name=default_cube_name,
- owner_id=user_id,
- cube_path=mem_cube_name_or_path,
- cube_id=mem_cube_id,
- )
- time_start = time.time()
- if default_mem_cube:
- try:
- default_mem_cube.dump(mem_cube_name_or_path, memory_types=[])
- except Exception as e:
- logger.error(f"Failed to dump default cube: {e}")
- time_end = time.time()
- logger.info(f"time user_register: dump default cube time is: {time_end - time_start}")
- # Register the default cube with MOS
- self.register_mem_cube(
- mem_cube_name_or_path_or_object=default_mem_cube,
- mem_cube_id=default_cube_id,
- user_id=user_id,
- memory_types=["act_mem"] if self.config.enable_activation_memory else [],
- default_config=default_cube_config, # use default cube config
- )
-
- # Add interests to the default cube if provided
- if interests:
- self.add(memory_content=interests, mem_cube_id=default_cube_id, user_id=user_id)
-
- return {
- "status": "success",
- "message": f"User {user_name} registered successfully with default cube {default_cube_id}",
- "user_id": user_id,
- "default_cube_id": default_cube_id,
- }
-
- except Exception as e:
- return {"status": "error", "message": f"Failed to register user: {e!s}"}
-
- def _get_further_suggestion(self, message: MessageList | None = None) -> list[str]:
- """Get further suggestion prompt."""
- try:
- dialogue_info = "\n".join([f"{msg['role']}: {msg['content']}" for msg in message[-2:]])
- further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info)
- message_list = [{"role": "system", "content": further_suggestion_prompt}]
- response = self.chat_llm.generate(message_list)
- clean_response = clean_json_response(response)
- response_json = json.loads(clean_response)
- return response_json["query"]
- except Exception as e:
- logger.error(f"Error getting further suggestion: {e}", exc_info=True)
- return []
-
- def get_suggestion_query(
- self, user_id: str, language: str = "zh", message: MessageList | None = None
- ) -> list[str]:
- """Get suggestion query from LLM.
- Args:
- user_id (str): User ID.
- language (str): Language for suggestions ("zh" or "en").
-
- Returns:
- list[str]: The suggestion query list.
- """
- if message:
- further_suggestion = self._get_further_suggestion(message)
- return further_suggestion
- if language == "zh":
- suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH
- else: # English
- suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN
- text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[
- "text_mem"
- ]
- if text_mem_result:
- memories = "\n".join([m.memory[:200] for m in text_mem_result[0]["memories"]])
- else:
- memories = ""
- message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}]
- response = self.chat_llm.generate(message_list)
- clean_response = clean_json_response(response)
- response_json = json.loads(clean_response)
- return response_json["query"]
-
- def chat(
- self,
- query: str,
- user_id: str,
- cube_id: str | None = None,
- history: MessageList | None = None,
- base_prompt: str | None = None,
- internet_search: bool = False,
- moscube: bool = False,
- top_k: int = 10,
- threshold: float = 0.5,
- session_id: str | None = None,
- ) -> str:
- """
- Chat with LLM with memory references and complete response.
- """
- self._load_user_cubes(user_id, self.default_cube_config)
- time_start = time.time()
- memories_result = super().search(
- query,
- user_id,
- install_cube_ids=[cube_id] if cube_id else None,
- top_k=top_k,
- mode="fine",
- internet_search=internet_search,
- moscube=moscube,
- session_id=session_id,
- )["text_mem"]
-
- memories_list = []
- if memories_result:
- memories_list = memories_result[0]["memories"]
- memories_list = self._filter_memories_by_threshold(memories_list, threshold)
- new_memories_list = []
- for m in memories_list:
- m.metadata.embedding = []
- new_memories_list.append(m)
- memories_list = new_memories_list
-
- system_prompt = super()._build_system_prompt(memories_list, base_prompt)
- if history is not None:
- # Use the provided history (even if it's empty)
- history_info = history[-20:]
- else:
- # Fall back to internal chat_history
- if user_id not in self.chat_history_manager:
- self._register_chat_history(user_id, session_id)
- history_info = self.chat_history_manager[user_id].chat_history[-20:]
- current_messages = [
- {"role": "system", "content": system_prompt},
- *history_info,
- {"role": "user", "content": query},
- ]
- logger.info("Start to get final answer...")
- response = self.chat_llm.generate(current_messages)
- time_end = time.time()
- self._start_post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=0.0,
- current_messages=current_messages,
- )
- return response, memories_list
-
- def chat_with_references(
- self,
- query: str,
- user_id: str,
- cube_id: str | None = None,
- history: MessageList | None = None,
- top_k: int = 20,
- internet_search: bool = False,
- moscube: bool = False,
- session_id: str | None = None,
- ) -> Generator[str, None, None]:
- """
- Chat with LLM with memory references and streaming output.
-
- Args:
- query (str): Query string.
- user_id (str): User ID.
- cube_id (str, optional): Custom cube ID for user.
- history (MessageList, optional): Chat history.
-
- Returns:
- Generator[str, None, None]: The response string generator with reference processing.
- """
-
- self._load_user_cubes(user_id, self.default_cube_config)
- time_start = time.time()
- memories_list = []
- yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n"
- memories_result = super().search(
- query,
- user_id,
- install_cube_ids=[cube_id] if cube_id else None,
- top_k=top_k,
- mode="fine",
- internet_search=internet_search,
- moscube=moscube,
- session_id=session_id,
- )["text_mem"]
-
- yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
- search_time_end = time.time()
- logger.info(
- f"time chat: search text_mem time user_id: {user_id} time is: {search_time_end - time_start}"
- )
- self._send_message_to_scheduler(
- user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_TASK_LABEL
- )
- if memories_result:
- memories_list = memories_result[0]["memories"]
- memories_list = self._filter_memories_by_threshold(memories_list)
-
- reference = prepare_reference_data(memories_list)
- yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
- # Build custom system prompt with relevant memories)
- system_prompt = self._build_enhance_system_prompt(user_id, memories_list)
- # Get chat history
- if user_id not in self.chat_history_manager:
- self._register_chat_history(user_id, session_id)
-
- chat_history = self.chat_history_manager[user_id]
- if history is not None:
- chat_history.chat_history = history[-20:]
- current_messages = [
- {"role": "system", "content": system_prompt},
- *chat_history.chat_history,
- {"role": "user", "content": query},
- ]
- logger.info(
- f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}"
- )
- yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n"
- # Generate response with custom prompt
- past_key_values = None
- response_stream = None
- if self.config.enable_activation_memory:
- # Handle activation memory (copy MOSCore logic)
- for mem_cube_id, mem_cube in self.mem_cubes.items():
- if mem_cube.act_mem and mem_cube_id == cube_id:
- kv_cache = next(iter(mem_cube.act_mem.get_all()), None)
- past_key_values = (
- kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None
- )
- if past_key_values is not None:
- logger.info("past_key_values is not None will apply to chat")
- else:
- logger.info("past_key_values is None will not apply to chat")
- break
- if self.config.chat_model.backend == "huggingface":
- response_stream = self.chat_llm.generate_stream(
- current_messages, past_key_values=past_key_values
- )
- elif self.config.chat_model.backend == "vllm":
- response_stream = self.chat_llm.generate_stream(current_messages)
- else:
- if self.config.chat_model.backend in ["huggingface", "vllm", "openai"]:
- response_stream = self.chat_llm.generate_stream(current_messages)
- else:
- response_stream = self.chat_llm.generate(current_messages)
-
- time_end = time.time()
- chat_time_end = time.time()
- logger.info(
- f"time chat: chat time user_id: {user_id} time is: {chat_time_end - search_time_end}"
- )
- # Simulate streaming output with proper reference handling using tiktoken
-
- # Initialize buffer for streaming
- buffer = ""
- full_response = ""
- token_count = 0
- # Use tiktoken for proper token-based chunking
- if self.config.chat_model.backend not in ["huggingface", "vllm", "openai"]:
- # For non-huggingface backends, we need to collect the full response first
- full_response_text = ""
- for chunk in response_stream:
- if chunk in ["", ""]:
- continue
- full_response_text += chunk
- response_stream = self._chunk_response_with_tiktoken(full_response_text, chunk_size=5)
- for chunk in response_stream:
- if chunk in ["", ""]:
- continue
- token_count += 1
- buffer += chunk
- full_response += chunk
-
- # Process buffer to ensure complete reference tags
- processed_chunk, remaining_buffer = process_streaming_references_complete(buffer)
-
- if processed_chunk:
- chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
- yield chunk_data
- buffer = remaining_buffer
-
- # Process any remaining buffer
- if buffer:
- processed_chunk, remaining_buffer = process_streaming_references_complete(buffer)
- if processed_chunk:
- chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
- yield chunk_data
-
- # set kvcache improve speed
- speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1)
- total_time = round(float(time_end - time_start), 1)
-
- yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n"
- # get further suggestion
- current_messages.append({"role": "assistant", "content": full_response})
- further_suggestion = self._get_further_suggestion(current_messages)
- logger.info(f"further_suggestion: {further_suggestion}")
- yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n"
- yield f"data: {json.dumps({'type': 'end'})}\n\n"
-
- # Asynchronous processing of logs, notifications and memory additions
- self._start_post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
-
- def get_all(
- self,
- user_id: str,
- memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"],
- mem_cube_ids: list[str] | None = None,
- ) -> list[dict[str, Any]]:
- """Get all memory items for a user.
-
- Args:
- user_id (str): The ID of the user.
- cube_id (str | None, optional): The ID of the cube. Defaults to None.
- memory_type (Literal["text_mem", "act_mem", "param_mem"]): The type of memory to get.
-
- Returns:
- list[dict[str, Any]]: A list of memory items with cube_id and memories structure.
- """
-
- # Load user cubes if not already loaded
- self._load_user_cubes(user_id, self.default_cube_config)
- time_start = time.time()
- memory_list = super().get_all(
- mem_cube_id=mem_cube_ids[0] if mem_cube_ids else None, user_id=user_id
- )[memory_type]
- get_all_time_end = time.time()
- logger.info(
- f"time get_all: get_all time user_id: {user_id} time is: {get_all_time_end - time_start}"
- )
- reformat_memory_list = []
- if memory_type == "text_mem":
- for memory in memory_list:
- memories = remove_embedding_recursive(memory["memories"])
- custom_type_ratios = {
- "WorkingMemory": 0.20,
- "LongTermMemory": 0.40,
- "UserMemory": 0.40,
- }
- tree_result, node_type_count = convert_graph_to_tree_forworkmem(
- memories, target_node_count=200, type_ratios=custom_type_ratios
- )
- # Ensure all node IDs are unique in the tree structure
- tree_result = ensure_unique_tree_ids(tree_result)
- memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
- children = tree_result["children"]
- children_sort = sort_children_by_memory_type(children)
- tree_result["children"] = children_sort
- memories_filtered["tree_structure"] = tree_result
- reformat_memory_list.append(
- {
- "cube_id": memory["cube_id"],
- "memories": [memories_filtered],
- "memory_statistics": node_type_count,
- }
- )
- elif memory_type == "act_mem":
- memories_list = []
- act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all()
- if act_mem_params:
- memories_data = act_mem_params[0].model_dump()
- records = memories_data.get("records", [])
- for record in records["text_memories"]:
- memories_list.append(
- {
- "id": memories_data["id"],
- "text": record,
- "create_time": records["timestamp"],
- "size": random.randint(1, 20),
- "modify_times": 1,
- }
- )
- reformat_memory_list.append(
- {
- "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0],
- "memories": memories_list,
- }
- )
- elif memory_type == "para_mem":
- act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all()
- logger.info(f"act_mem_params: {act_mem_params}")
- reformat_memory_list.append(
- {
- "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0],
- "memories": act_mem_params[0].model_dump(),
- }
- )
- make_format_time_end = time.time()
- logger.info(
- f"time get_all: make_format time user_id: {user_id} time is: {make_format_time_end - get_all_time_end}"
- )
- return reformat_memory_list
-
- def _get_subgraph(
- self, query: str, mem_cube_id: str, user_id: str | None = None, top_k: int = 5
- ) -> list[dict[str, Any]]:
- result = {"para_mem": [], "act_mem": [], "text_mem": []}
- if self.config.enable_textual_memory and self.mem_cubes[mem_cube_id].text_mem:
- result["text_mem"].append(
- {
- "cube_id": mem_cube_id,
- "memories": self.mem_cubes[mem_cube_id].text_mem.get_relevant_subgraph(
- query, top_k=top_k
- ),
- }
- )
- return result
-
- def get_subgraph(
- self,
- user_id: str,
- query: str,
- mem_cube_ids: list[str] | None = None,
- top_k: int = 20,
- ) -> list[dict[str, Any]]:
- """Get all memory items for a user.
-
- Args:
- user_id (str): The ID of the user.
- cube_id (str | None, optional): The ID of the cube. Defaults to None.
- mem_cube_ids (list[str], optional): The IDs of the cubes. Defaults to None.
-
- Returns:
- list[dict[str, Any]]: A list of memory items with cube_id and memories structure.
- """
-
- # Load user cubes if not already loaded
- self._load_user_cubes(user_id, self.default_cube_config)
- memory_list = self._get_subgraph(
- query=query, mem_cube_id=mem_cube_ids[0], user_id=user_id, top_k=top_k
- )["text_mem"]
- reformat_memory_list = []
- for memory in memory_list:
- memories = remove_embedding_recursive(memory["memories"])
- custom_type_ratios = {"WorkingMemory": 0.20, "LongTermMemory": 0.40, "UserMemory": 0.4}
- tree_result, node_type_count = convert_graph_to_tree_forworkmem(
- memories, target_node_count=150, type_ratios=custom_type_ratios
- )
- # Ensure all node IDs are unique in the tree structure
- tree_result = ensure_unique_tree_ids(tree_result)
- memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
- children = tree_result["children"]
- children_sort = sort_children_by_memory_type(children)
- tree_result["children"] = children_sort
- memories_filtered["tree_structure"] = tree_result
- reformat_memory_list.append(
- {
- "cube_id": memory["cube_id"],
- "memories": [memories_filtered],
- "memory_statistics": node_type_count,
- }
- )
-
- return reformat_memory_list
-
- def search(
- self,
- query: str,
- user_id: str,
- install_cube_ids: list[str] | None = None,
- top_k: int = 10,
- mode: Literal["fast", "fine"] = "fast",
- session_id: str | None = None,
- ):
- """Search memories for a specific user."""
-
- # Load user cubes if not already loaded
- time_start = time.time()
- self._load_user_cubes(user_id, self.default_cube_config)
- load_user_cubes_time_end = time.time()
- logger.info(
- f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}"
- )
- search_result = super().search(
- query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id
- )
- search_time_end = time.time()
- logger.info(
- f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}"
- )
- text_memory_list = search_result["text_mem"]
- reformat_memory_list = []
- for memory in text_memory_list:
- memories_list = []
- for data in memory["memories"]:
- memories = data.model_dump()
- memories["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["embedding"] = []
- memories["metadata"]["sources"] = []
- memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["id"] = memories["id"]
- memories["metadata"]["memory"] = memories["memory"]
- memories_list.append(memories)
- reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list})
- logger.info(f"search memory list is : {reformat_memory_list}")
- search_result["text_mem"] = reformat_memory_list
-
- pref_memory_list = search_result["pref_mem"]
- reformat_pref_memory_list = []
- for memory in pref_memory_list:
- memories_list = []
- for data in memory["memories"]:
- memories = data.model_dump()
- memories["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["embedding"] = []
- memories["metadata"]["sources"] = []
- memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["id"] = memories["id"]
- memories["metadata"]["memory"] = memories["memory"]
- memories_list.append(memories)
- reformat_pref_memory_list.append(
- {"cube_id": memory["cube_id"], "memories": memories_list}
- )
- search_result["pref_mem"] = reformat_pref_memory_list
- time_end = time.time()
- logger.info(
- f"time search: total time for user_id: {user_id} time is: {time_end - time_start}"
- )
- return search_result
-
- def add(
- self,
- user_id: str,
- messages: MessageList | None = None,
- memory_content: str | None = None,
- doc_path: str | None = None,
- mem_cube_id: str | None = None,
- source: str | None = None,
- user_profile: bool = False,
- session_id: str | None = None,
- task_id: str | None = None, # Add task_id parameter
- ):
- """Add memory for a specific user."""
-
- # Load user cubes if not already loaded
- self._load_user_cubes(user_id, self.default_cube_config)
- result = super().add(
- messages,
- memory_content,
- doc_path,
- mem_cube_id,
- user_id,
- session_id=session_id,
- task_id=task_id,
- )
- if user_profile:
- try:
- user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0]
- user_interests = user_interests.replace(",", " ")
- user_profile_memories = self.mem_cubes[
- mem_cube_id
- ].text_mem.internet_retriever.retrieve_from_internet(query=user_interests, top_k=5)
- for memory in user_profile_memories:
- self.mem_cubes[mem_cube_id].text_mem.add(memory)
- except Exception as e:
- logger.error(
- f"Failed to retrieve user profile: {e}, memory_content: {memory_content}"
- )
-
- return result
-
- def list_users(self) -> list:
- """List all registered users."""
- return self.user_manager.list_users()
-
- def get_user_info(self, user_id: str) -> dict:
- """Get user information including accessible cubes."""
- # Use MOSCore's built-in user validation
- # Validate user access
- self._validate_user_access(user_id)
-
- result = super().get_user_info()
-
- return result
-
- def share_cube_with_user(self, cube_id: str, owner_user_id: str, target_user_id: str) -> bool:
- """Share a cube with another user."""
- # Use MOSCore's built-in cube access validation
- self._validate_cube_access(owner_user_id, cube_id)
-
- result = super().share_cube_with_user(cube_id, target_user_id)
-
- return result
-
- def clear_user_chat_history(self, user_id: str) -> None:
- """Clear chat history for a specific user."""
- # Validate user access
- self._validate_user_access(user_id)
-
- super().clear_messages(user_id)
-
- def update_user_config(self, user_id: str, config: MOSConfig) -> bool:
- """Update user configuration.
-
- Args:
- user_id (str): The user ID.
- config (MOSConfig): The new configuration.
-
- Returns:
- bool: True if successful, False otherwise.
- """
- try:
- # Save to persistent storage
- success = self.user_manager.save_user_config(user_id, config)
- if success:
- # Update in-memory config
- self.user_configs[user_id] = config
- logger.info(f"Updated configuration for user {user_id}")
-
- return success
- except Exception as e:
- logger.error(f"Failed to update user config for {user_id}: {e}")
- return False
-
- def get_user_config(self, user_id: str) -> MOSConfig | None:
- """Get user configuration.
-
- Args:
- user_id (str): The user ID.
-
- Returns:
- MOSConfig | None: The user's configuration or None if not found.
- """
- return self.user_manager.get_user_config(user_id)
-
- def get_active_user_count(self) -> int:
- """Get the number of active user configurations in memory."""
- return len(self.user_configs)
-
- def get_user_instance_info(self) -> dict[str, Any]:
- """Get information about user configurations in memory."""
- return {
- "active_instances": len(self.user_configs),
- "max_instances": self.max_user_instances,
- "user_ids": list(self.user_configs.keys()),
- "lru_order": list(self.user_configs.keys()), # OrderedDict maintains insertion order
- }
diff --git a/src/memos/mem_os/product_server.py b/src/memos/mem_os/product_server.py
deleted file mode 100644
index 80aefea85..000000000
--- a/src/memos/mem_os/product_server.py
+++ /dev/null
@@ -1,457 +0,0 @@
-import asyncio
-import time
-
-from datetime import datetime
-from typing import Literal
-
-from memos.context.context import ContextThread
-from memos.llms.base import BaseLLM
-from memos.log import get_logger
-from memos.mem_cube.navie import NaiveMemCube
-from memos.mem_os.product import _format_mem_block
-from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem
-from memos.templates.mos_prompts import (
- get_memos_prompt,
-)
-from memos.types import MessageList
-
-
-logger = get_logger(__name__)
-
-
-class MOSServer:
- def __init__(
- self,
- mem_reader: BaseMemReader | None = None,
- llm: BaseLLM | None = None,
- online_bot: bool = False,
- ):
- self.mem_reader = mem_reader
- self.chat_llm = llm
- self.online_bot = online_bot
-
- def chat(
- self,
- query: str,
- user_id: str,
- cube_id: str | None = None,
- mem_cube: NaiveMemCube | None = None,
- history: MessageList | None = None,
- base_prompt: str | None = None,
- internet_search: bool = False,
- moscube: bool = False,
- top_k: int = 10,
- threshold: float = 0.5,
- session_id: str | None = None,
- ) -> str:
- """
- Chat with LLM with memory references and complete response.
- """
- time_start = time.time()
- memories_result = mem_cube.text_mem.search(
- query=query,
- user_name=cube_id,
- top_k=top_k,
- mode="fine",
- manual_close_internet=not internet_search,
- moscube=moscube,
- info={
- "user_id": user_id,
- "session_id": session_id,
- "chat_history": history,
- },
- )
-
- memories_list = []
- if memories_result:
- memories_list = self._filter_memories_by_threshold(memories_result, threshold)
- new_memories_list = []
- for m in memories_list:
- m.metadata.embedding = []
- new_memories_list.append(m)
- memories_list = new_memories_list
- system_prompt = self._build_system_prompt(memories_list, base_prompt)
-
- history_info = []
- if history:
- history_info = history[-20:]
- current_messages = [
- {"role": "system", "content": system_prompt},
- *history_info,
- {"role": "user", "content": query},
- ]
- response = self.chat_llm.generate(current_messages)
- time_end = time.time()
- self._start_post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- session_id=session_id,
- query=query,
- full_response=response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=0.0,
- current_messages=current_messages,
- mem_cube=mem_cube,
- history=history,
- )
- return response, memories_list
-
- def add(
- self,
- user_id: str,
- cube_id: str,
- mem_cube: NaiveMemCube,
- messages: MessageList,
- session_id: str | None = None,
- history: MessageList | None = None,
- ) -> list[str]:
- memories = self.mem_reader.get_memory(
- [messages],
- type="chat",
- info={
- "user_id": user_id,
- "session_id": session_id,
- "chat_history": history,
- },
- )
- flattened_memories = [mm for m in memories for mm in m]
- mem_id_list: list[str] = mem_cube.text_mem.add(
- flattened_memories,
- user_name=cube_id,
- )
- return mem_id_list
-
- def search(
- self,
- user_id: str,
- cube_id: str,
- session_id: str | None = None,
- ) -> None:
- NotImplementedError("Not implemented")
-
- def _filter_memories_by_threshold(
- self,
- memories: list[TextualMemoryItem],
- threshold: float = 0.30,
- min_num: int = 3,
- memory_type: Literal["OuterMemory"] = "OuterMemory",
- ) -> list[TextualMemoryItem]:
- """
- Filter memories by threshold and type, at least min_num memories for Non-OuterMemory.
- Args:
- memories: list[TextualMemoryItem],
- threshold: float,
- min_num: int,
- memory_type: Literal["OuterMemory"],
- Returns:
- list[TextualMemoryItem]
- """
- sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True)
- filtered_person = [m for m in memories if m.metadata.memory_type != memory_type]
- filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type]
- filtered = []
- per_memory_count = 0
- for m in sorted_memories:
- if m.metadata.relativity >= threshold:
- if m.metadata.memory_type != memory_type:
- per_memory_count += 1
- filtered.append(m)
- if len(filtered) < min_num:
- filtered = filtered_person[:min_num] + filtered_outer[:min_num]
- else:
- if per_memory_count < min_num:
- filtered += filtered_person[per_memory_count:min_num]
- filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True)
- return filtered_memory
-
- def _build_base_system_prompt(
- self,
- base_prompt: str | None = None,
- tone: str = "friendly",
- verbosity: str = "mid",
- mode: str = "enhance",
- ) -> str:
- """
- Build base system prompt without memory references.
- """
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode)
- prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
- return prefix + sys_body
-
- def _build_system_prompt(
- self,
- memories: list[TextualMemoryItem] | list[str] | None = None,
- base_prompt: str | None = None,
- **kwargs,
- ) -> str:
- """Build system prompt with optional memories context."""
- if base_prompt is None:
- base_prompt = (
- "You are a knowledgeable and helpful AI assistant. "
- "You have access to conversation memories that help you provide more personalized responses. "
- "Use the memories to understand the user's context, preferences, and past interactions. "
- "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
- )
-
- memory_context = ""
- if memories:
- memory_list = []
- for i, memory in enumerate(memories, 1):
- if isinstance(memory, TextualMemoryItem):
- text_memory = memory.memory
- else:
- if not isinstance(memory, str):
- logger.error("Unexpected memory type.")
- text_memory = memory
- memory_list.append(f"{i}. {text_memory}")
- memory_context = "\n".join(memory_list)
-
- if "{memories}" in base_prompt:
- return base_prompt.format(memories=memory_context)
- elif base_prompt and memories:
- # For backward compatibility, append memories if no placeholder is found
- memory_context_with_header = "\n\n## Memories:\n" + memory_context
- return base_prompt + memory_context_with_header
- return base_prompt
-
- def _build_memory_context(
- self,
- memories_all: list[TextualMemoryItem],
- mode: str = "enhance",
- ) -> str:
- """
- Build memory context to be included in user message.
- """
- if not memories_all:
- return ""
-
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
-
- if mode == "enhance":
- return (
- "# Memories\n## PersonalMemory (ordered)\n"
- + mem_block_p
- + "\n## OuterMemory (ordered)\n"
- + mem_block_o
- + "\n\n"
- )
- else:
- mem_block = mem_block_o + "\n" + mem_block_p
- return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n"
-
- def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]:
- """
- Extract reference information from the response and return clean text.
-
- Args:
- response (str): The complete response text.
-
- Returns:
- tuple[str, list[dict]]: A tuple containing:
- - clean_text: Text with reference markers removed
- - references: List of reference information
- """
- import re
-
- try:
- references = []
- # Pattern to match [refid:memoriesID]
- pattern = r"\[(\d+):([^\]]+)\]"
-
- matches = re.findall(pattern, response)
- for ref_number, memory_id in matches:
- references.append({"memory_id": memory_id, "reference_number": int(ref_number)})
-
- # Remove all reference markers from the text to get clean text
- clean_text = re.sub(pattern, "", response)
-
- # Clean up any extra whitespace that might be left after removing markers
- clean_text = re.sub(r"\s+", " ", clean_text).strip()
-
- return clean_text, references
- except Exception as e:
- logger.error(f"Error extracting references from response: {e}", exc_info=True)
- return response, []
-
- async def _post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- mem_cube: NaiveMemCube | None = None,
- session_id: str | None = None,
- history: MessageList | None = None,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions
- """
- try:
- logger.info(
- f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}"
- )
- logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
-
- clean_response, extracted_references = self._extract_references_from_response(
- full_response
- )
- logger.info(f"Extracted {len(extracted_references)} references from response")
-
- # Send chat report notifications asynchronously
- if self.online_bot:
- try:
- from memos.memos_tools.notification_utils import (
- send_online_bot_notification_async,
- )
-
- # Prepare notification data
- chat_data = {
- "query": query,
- "user_id": user_id,
- "cube_id": cube_id,
- "system_prompt": system_prompt,
- "full_response": full_response,
- }
-
- system_data = {
- "references": extracted_references,
- "time_start": time_start,
- "time_end": time_end,
- "speed_improvement": speed_improvement,
- }
-
- emoji_config = {"chat": "๐ฌ", "system_info": "๐"}
-
- await send_online_bot_notification_async(
- online_bot=self.online_bot,
- header_name="MemOS Chat Report",
- sub_title_name="chat_with_references",
- title_color="#00956D",
- other_data1=chat_data,
- other_data2=system_data,
- emoji=emoji_config,
- )
- except Exception as e:
- logger.warning(f"Failed to send chat notification (async): {e}")
-
- self.add(
- user_id=user_id,
- cube_id=cube_id,
- mem_cube=mem_cube,
- session_id=session_id,
- history=history,
- messages=[
- {
- "role": "user",
- "content": query,
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- {
- "role": "assistant",
- "content": clean_response, # Store clean text without reference markers
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- ],
- )
-
- logger.info(f"Post-chat processing completed for user {user_id}")
-
- except Exception as e:
- logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True)
-
- def _start_post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- mem_cube: NaiveMemCube | None = None,
- session_id: str | None = None,
- history: MessageList | None = None,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments
- """
-
- def run_async_in_thread():
- """Running asynchronous tasks in a new thread"""
- try:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- mem_cube=mem_cube,
- session_id=session_id,
- history=history,
- )
- )
- finally:
- loop.close()
- except Exception as e:
- logger.error(
- f"Error in thread-based post-chat processing for user {user_id}: {e}",
- exc_info=True,
- )
-
- try:
- # Try to get the current event loop
- asyncio.get_running_loop()
- # Create task and store reference to prevent garbage collection
- task = asyncio.create_task(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
- )
- # Add exception handling for the background task
- task.add_done_callback(
- lambda t: (
- logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
- )
- if t.exception()
- else None
- )
- )
- except RuntimeError:
- # No event loop, run in a new thread with context propagation
- thread = ContextThread(
- target=run_async_in_thread,
- name=f"PostChatProcessing-{user_id}",
- # Set as a daemon thread to avoid blocking program exit
- daemon=True,
- )
- thread.start()
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index ec431c253..671190e6f 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -47,13 +47,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
graph_db_backend_map = {
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
"postgres": APIConfig.get_postgres_config(user_id=user_id),
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
- graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py
index 3498f596a..4d122ca4e 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py
@@ -9,6 +9,7 @@
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import (
InternetGoogleRetriever,
)
+from memos.memories.textual.tree_text_memory.retrieve.tavilysearch import InternetTavilyRetriever
from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever
from memos.memos_tools.singleton import singleton_factory
@@ -21,6 +22,7 @@ class InternetRetrieverFactory:
"bing": InternetGoogleRetriever, # TODO: Implement BingRetriever
"xinyu": XinyuSearchRetriever,
"bocha": BochaAISearchRetriever,
+ "tavily": InternetTavilyRetriever,
}
@classmethod
@@ -81,6 +83,14 @@ def from_config(
reader=MemReaderFactory.from_config(config.reader),
max_results=config.max_results,
)
+ elif backend == "tavily":
+ return retriever_class(
+ api_key=config.api_key,
+ embedder=embedder,
+ max_results=config.max_results,
+ search_depth=config.search_depth,
+ include_answer=config.include_answer,
+ )
else:
raise ValueError(f"Unsupported backend: {backend}")
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py
new file mode 100644
index 000000000..fd7180b01
--- /dev/null
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py
@@ -0,0 +1,327 @@
+"""Tavily Search API retriever for tree text memory."""
+
+from concurrent.futures import as_completed
+from datetime import datetime
+from typing import Any
+
+from memos.context.context import ContextThreadPoolExecutor
+from memos.dependency import require_python_package
+from memos.embedders.factory import OllamaEmbedder
+from memos.log import get_logger
+from memos.mem_reader.read_multi_modal import detect_lang
+from memos.memories.textual.item import (
+ SearchedTreeNodeTextualMemoryMetadata,
+ SourceMessage,
+ TextualMemoryItem,
+)
+
+
+logger = get_logger(__name__)
+
+
+class InternetTavilyRetriever:
+ """Tavily retriever that converts search results into TextualMemoryItem objects"""
+
+ @require_python_package(
+ import_name="tavily",
+ install_command="pip install tavily-python",
+ install_link="https://github.com/tavily-ai/tavily-python",
+ )
+ def __init__(
+ self,
+ api_key: str,
+ embedder: OllamaEmbedder,
+ max_results: int = 10,
+ search_depth: str = "basic",
+ include_answer: bool = False,
+ ):
+ """
+ Initialize Tavily Search retriever.
+
+ Args:
+ api_key: Tavily API key
+ embedder: Embedder instance for generating embeddings
+ max_results: Maximum number of search results to retrieve
+ search_depth: Search depth ('basic' or 'advanced')
+ include_answer: Whether to include an AI-generated answer
+ """
+ from tavily import TavilyClient
+
+ self.client = TavilyClient(api_key=api_key)
+ self.embedder = embedder
+ self.max_results = max_results
+ self.search_depth = search_depth
+ self.include_answer = include_answer
+
+ import jieba.analyse
+
+ self.zh_fast_keywords_extractor = jieba.analyse.TextRank()
+
+ def _extract_tags(self, title: str, content: str, summary: str, parsed_goal=None) -> list[str]:
+ """
+ Extract tags from title, content and summary.
+
+ Args:
+ title: Article title
+ content: Article content
+ summary: Article summary
+ parsed_goal: Parsed task goal (optional)
+
+ Returns:
+ List of extracted tags
+ """
+ tags = []
+
+ tags.append("tavily_search")
+ tags.append("news")
+
+ text = f"{title} {content} {summary}".lower()
+
+ keywords = {
+ "economy": [
+ "economy",
+ "GDP",
+ "growth",
+ "production",
+ "industry",
+ "investment",
+ "consumption",
+ "market",
+ "trade",
+ "finance",
+ ],
+ "politics": [
+ "politics",
+ "government",
+ "policy",
+ "meeting",
+ "leader",
+ "election",
+ "parliament",
+ "ministry",
+ ],
+ "technology": [
+ "technology",
+ "tech",
+ "innovation",
+ "digital",
+ "internet",
+ "AI",
+ "artificial intelligence",
+ "software",
+ "hardware",
+ ],
+ "sports": [
+ "sports",
+ "game",
+ "athlete",
+ "olympic",
+ "championship",
+ "tournament",
+ "team",
+ "player",
+ ],
+ "culture": [
+ "culture",
+ "education",
+ "art",
+ "history",
+ "literature",
+ "music",
+ "film",
+ "museum",
+ ],
+ "health": [
+ "health",
+ "medical",
+ "pandemic",
+ "hospital",
+ "doctor",
+ "medicine",
+ "disease",
+ "treatment",
+ ],
+ "environment": [
+ "environment",
+ "ecology",
+ "pollution",
+ "green",
+ "climate",
+ "sustainability",
+ "renewable",
+ ],
+ }
+
+ for category, words in keywords.items():
+ if any(word in text for word in words):
+ tags.append(category)
+
+ if parsed_goal and hasattr(parsed_goal, "tags"):
+ tags.extend(parsed_goal.tags)
+
+ return list(set(tags))[:15]
+
+ def retrieve_from_internet(
+ self, query: str, top_k: int = 10, parsed_goal=None, info=None, mode="fast"
+ ) -> list[TextualMemoryItem]:
+ """
+ Retrieve results from the internet using Tavily Search API.
+
+ Args:
+ query: Search query
+ top_k: Number of results to retrieve
+ parsed_goal: Parsed task goal (optional)
+ info (dict): Metadata for memory consumption tracking
+ mode: Retrieval mode ('fast' or other)
+
+ Returns:
+ List of TextualMemoryItem
+ """
+ try:
+ response = self.client.search(
+ query=query,
+ max_results=min(top_k, self.max_results),
+ search_depth=self.search_depth,
+ include_answer=self.include_answer,
+ )
+ search_results = response.get("results", [])
+ except Exception:
+ import traceback
+
+ logger.error(f"Tavily search error: {traceback.format_exc()}")
+ return []
+
+ return self._convert_to_mem_items(search_results, query, parsed_goal, info, mode=mode)
+
+ def _convert_to_mem_items(
+ self, search_results: list[dict], query: str, parsed_goal=None, info=None, mode="fast"
+ ):
+ """Convert Tavily search results into TextualMemoryItem objects."""
+ memory_items = []
+ if not info:
+ info = {"user_id": "", "session_id": ""}
+
+ with ContextThreadPoolExecutor(max_workers=8) as executor:
+ futures = [
+ executor.submit(self._process_result, r, query, parsed_goal, info, mode=mode)
+ for r in search_results
+ ]
+ for future in as_completed(futures):
+ try:
+ memory_items.extend(future.result())
+ except Exception as e:
+ logger.error(f"Error processing Tavily search result: {e}")
+
+ unique_memory_items = {item.memory: item for item in memory_items}
+ return list(unique_memory_items.values())
+
+ def _process_result(
+ self, result: dict, query: str, parsed_goal: str, info: dict[str, Any], mode="fast"
+ ) -> list[TextualMemoryItem]:
+ """Process one Tavily search result into TextualMemoryItem."""
+ title = result.get("title", "")
+ content = result.get("content", "")
+ url = result.get("url", "")
+ publish_time = result.get("published_date", "")
+
+ if publish_time:
+ try:
+ publish_time = datetime.fromisoformat(publish_time.replace("Z", "+00:00")).strftime(
+ "%Y-%m-%d"
+ )
+ except Exception:
+ publish_time = datetime.now().strftime("%Y-%m-%d")
+ else:
+ publish_time = datetime.now().strftime("%Y-%m-%d")
+
+ summary = content[:500] if content else ""
+
+ if mode == "fast":
+ info_ = info.copy()
+ user_id = info_.pop("user_id", "")
+ session_id = info_.pop("session_id", "")
+ lang = detect_lang(summary)
+ tags = (
+ self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3]
+ if lang == "zh"
+ else self._extract_tags(title, content, summary)[:3]
+ )
+
+ return [
+ TextualMemoryItem(
+ memory=(
+ f"[Outer internet view] Title: {title}\nNewsTime:"
+ f" {publish_time}\nSummary:"
+ f" {summary}\n"
+ ),
+ metadata=SearchedTreeNodeTextualMemoryMetadata(
+ user_id=user_id,
+ session_id=session_id,
+ memory_type="OuterMemory",
+ status="activated",
+ type="fact",
+ source="web",
+ sources=[SourceMessage(type="web", url=url)] if url else [],
+ visibility="public",
+ info=info_,
+ background="",
+ confidence=0.99,
+ usage=[],
+ tags=tags,
+ key=title,
+ embedding=self.embedder.embed([content])[0],
+ internet_info={
+ "title": title,
+ "url": url,
+ "site_name": "",
+ "site_icon": None,
+ "summary": summary,
+ },
+ ),
+ )
+ ]
+ else:
+ info_ = info.copy()
+ user_id = info_.pop("user_id", "")
+ session_id = info_.pop("session_id", "")
+ lang = detect_lang(summary)
+ tags = (
+ self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3]
+ if lang == "zh"
+ else self._extract_tags(title, content, summary)[:3]
+ )
+
+ return [
+ TextualMemoryItem(
+ memory=(
+ f"[Outer internet view] Title: {title}\nNewsTime:"
+ f" {publish_time}\nSummary:"
+ f" {summary}\n"
+ f"Content: {content}"
+ ),
+ metadata=SearchedTreeNodeTextualMemoryMetadata(
+ user_id=user_id,
+ session_id=session_id,
+ memory_type="OuterMemory",
+ status="activated",
+ type="fact",
+ source="web",
+ sources=[SourceMessage(type="web", url=url)] if url else [],
+ visibility="public",
+ info=info_,
+ background="",
+ confidence=0.99,
+ usage=[],
+ tags=tags,
+ key=title,
+ embedding=self.embedder.embed([content])[0],
+ internet_info={
+ "title": title,
+ "url": url,
+ "site_name": "",
+ "site_icon": None,
+ "summary": summary,
+ },
+ ),
+ )
+ ]
diff --git a/tests/api/test_memory_handler_delete.py b/tests/api/test_memory_handler_delete.py
new file mode 100644
index 000000000..7d60f946b
--- /dev/null
+++ b/tests/api/test_memory_handler_delete.py
@@ -0,0 +1,142 @@
+from unittest.mock import Mock
+
+from memos.api.handlers.memory_handler import handle_delete_memories
+from memos.api.product_models import DeleteMemoryRequest
+
+
+def _build_naive_mem_cube() -> Mock:
+ naive_mem_cube = Mock()
+ naive_mem_cube.text_mem = Mock()
+ naive_mem_cube.pref_mem = Mock()
+ return naive_mem_cube
+
+
+def test_delete_memories_quick_by_user_id():
+ naive_mem_cube = _build_naive_mem_cube()
+ req = DeleteMemoryRequest(user_id="u_1")
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "success"
+ naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with(
+ writable_cube_ids=None,
+ filter={"and": [{"user_id": "u_1"}]},
+ )
+ naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with(
+ filter={"and": [{"user_id": "u_1"}]}
+ )
+
+
+def test_delete_memories_quick_by_conversation_alias():
+ naive_mem_cube = _build_naive_mem_cube()
+ req = DeleteMemoryRequest(conversation_id="conv_1")
+
+ assert req.session_id == "conv_1"
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "success"
+ naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with(
+ writable_cube_ids=None,
+ filter={"and": [{"session_id": "conv_1"}]},
+ )
+ naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with(
+ filter={"and": [{"session_id": "conv_1"}]}
+ )
+
+
+def test_delete_memories_filter_and_quick_conditions():
+ naive_mem_cube = _build_naive_mem_cube()
+ req = DeleteMemoryRequest(
+ filter={"and": [{"memory_type": "WorkingMemory"}]},
+ user_id="u_1",
+ session_id="s_1",
+ )
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "success"
+ naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with(
+ writable_cube_ids=None,
+ filter={
+ "and": [
+ {"memory_type": "WorkingMemory"},
+ {"user_id": "u_1", "session_id": "s_1"},
+ ]
+ },
+ )
+ naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with(
+ filter={
+ "and": [
+ {"memory_type": "WorkingMemory"},
+ {"user_id": "u_1", "session_id": "s_1"},
+ ]
+ }
+ )
+
+
+def test_delete_memories_filter_or_with_distribution():
+ naive_mem_cube = _build_naive_mem_cube()
+ req = DeleteMemoryRequest(
+ filter={"or": [{"memory_type": "WorkingMemory"}, {"memory_type": "UserMemory"}]},
+ user_id="u_1",
+ )
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "success"
+ naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with(
+ writable_cube_ids=None,
+ filter={
+ "or": [
+ {"memory_type": "WorkingMemory", "user_id": "u_1"},
+ {"memory_type": "UserMemory", "user_id": "u_1"},
+ ]
+ },
+ )
+ naive_mem_cube.pref_mem.delete_by_filter.assert_called_once_with(
+ filter={
+ "or": [
+ {"memory_type": "WorkingMemory", "user_id": "u_1"},
+ {"memory_type": "UserMemory", "user_id": "u_1"},
+ ]
+ }
+ )
+
+
+def test_delete_memories_reject_multiple_modes():
+ naive_mem_cube = _build_naive_mem_cube()
+ req = DeleteMemoryRequest(memory_ids=["m_1"], user_id="u_1")
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "failure"
+ assert "Exactly one delete mode must be provided" in resp.message
+ naive_mem_cube.text_mem.delete_by_filter.assert_not_called()
+ naive_mem_cube.text_mem.delete_by_memory_ids.assert_not_called()
+
+
+def test_delete_memories_reject_empty_filter():
+ naive_mem_cube = _build_naive_mem_cube()
+ req = DeleteMemoryRequest(filter={})
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "failure"
+ assert "filter cannot be empty" in resp.message
+ naive_mem_cube.text_mem.delete_by_filter.assert_not_called()
+ naive_mem_cube.pref_mem.delete_by_filter.assert_not_called()
+
+
+def test_delete_memories_with_pref_mem_disabled():
+ naive_mem_cube = _build_naive_mem_cube()
+ naive_mem_cube.pref_mem = None
+ req = DeleteMemoryRequest(user_id="u_1")
+
+ resp = handle_delete_memories(req, naive_mem_cube)
+
+ assert resp.data["status"] == "success"
+ naive_mem_cube.text_mem.delete_by_filter.assert_called_once_with(
+ writable_cube_ids=None,
+ filter={"and": [{"user_id": "u_1"}]},
+ )
diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py
deleted file mode 100644
index 857b290c5..000000000
--- a/tests/api/test_product_router.py
+++ /dev/null
@@ -1,422 +0,0 @@
-"""
-Unit tests for product_router input/output format validation.
-
-This module tests that the product_router endpoints correctly validate
-input request formats and return properly formatted responses.
-"""
-
-from unittest.mock import Mock, patch
-
-import pytest
-
-from fastapi.testclient import TestClient
-
-# Patch the MOS_PRODUCT_INSTANCE directly after import
-# Patch MOS_PRODUCT_INSTANCE and MOSProduct so we can test the FastAPI router
-# without initializing the full MemOS product stack.
-import memos.api.routers.product_router as pr_module
-
-
-_mock_mos_instance = Mock()
-pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance
-pr_module.get_mos_product_instance = lambda: _mock_mos_instance
-with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance):
- from memos.api import product_api
-
-
-@pytest.fixture(scope="module")
-def mock_mos_product_instance():
- """Mock get_mos_product_instance for all tests."""
- # Ensure the mock is set
- pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance
- pr_module.get_mos_product_instance = lambda: _mock_mos_instance
- yield product_api.app, _mock_mos_instance
-
-
-@pytest.fixture
-def client(mock_mos_product_instance):
- """Create test client for product_api."""
- app, _ = mock_mos_product_instance
- return TestClient(app)
-
-
-@pytest.fixture
-def mock_mos_product(mock_mos_product_instance):
- """Get the mocked MOSProduct instance."""
- _, mock_instance = mock_mos_product_instance
- # Ensure get_mos_product_instance returns this mock
- import memos.api.routers.product_router as pr_module
-
- pr_module.get_mos_product_instance = lambda: mock_instance
- pr_module.MOS_PRODUCT_INSTANCE = mock_instance
- return mock_instance
-
-
-@pytest.fixture(autouse=True)
-def setup_mock_mos_product(mock_mos_product):
- """Set up default return values for MOSProduct methods."""
- # Set up default return values for methods
- mock_mos_product.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_mos_product.add.return_value = None
- mock_mos_product.chat.return_value = ("test response", [])
- mock_mos_product.chat_with_references.return_value = iter(
- ['data: {"type": "content", "data": "test"}\n\n']
- )
- # Ensure get_all and get_subgraph return proper list format (MemoryResponse expects list)
- default_memory_result = [{"cube_id": "test_cube", "memories": []}]
- mock_mos_product.get_all.return_value = default_memory_result
- mock_mos_product.get_subgraph.return_value = default_memory_result
- mock_mos_product.get_suggestion_query.return_value = ["suggestion1", "suggestion2"]
- # Ensure get_mos_product_instance returns the mock
- import memos.api.routers.product_router as pr_module
-
- pr_module.get_mos_product_instance = lambda: mock_mos_product
-
-
-class TestProductRouterSearch:
- """Test /search endpoint input/output format."""
-
- def test_search_valid_input_output(self, mock_mos_product, client):
- """Test search endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- "mem_cube_id": "test_cube",
- "top_k": 10,
- }
-
- response = client.post("/product/search", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert isinstance(data["data"], dict)
-
- # Verify MOSProduct.search was called with correct parameters
- mock_mos_product.search.assert_called_once()
- call_kwargs = mock_mos_product.search.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["query"] == "test query"
-
- def test_search_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test search endpoint with missing required field."""
- request_data = {
- "query": "test query",
- }
-
- response = client.post("/product/search", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_search_response_format(self, mock_mos_product, client):
- """Test search endpoint returns SearchResponse format."""
- mock_mos_product.search.return_value = {
- "text_mem": [{"cube_id": "test_cube", "memories": []}],
- "act_mem": [],
- "para_mem": [],
- }
-
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- }
-
- response = client.post("/product/search", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Search completed successfully"
- assert isinstance(data["data"], dict)
- assert "text_mem" in data["data"]
-
-
-class TestProductRouterAdd:
- """Test /add endpoint input/output format."""
-
- def test_add_valid_input_output(self, mock_mos_product, client):
- """Test add endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "memory_content": "test memory content",
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/add", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert data["data"] is None # SimpleResponse has None data
-
- # Verify MOSProduct.add was called with correct parameters
- mock_mos_product.add.assert_called_once()
- call_kwargs = mock_mos_product.add.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["memory_content"] == "test memory content"
-
- def test_add_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test add endpoint with missing required field."""
- request_data = {
- "memory_content": "test memory content",
- }
-
- response = client.post("/product/add", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_add_response_format(self, mock_mos_product, client):
- """Test add endpoint returns SimpleResponse format."""
- request_data = {
- "user_id": "test_user",
- "memory_content": "test memory content",
- }
-
- response = client.post("/product/add", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Memory created successfully"
- assert data["data"] is None
-
-
-class TestProductRouterChatComplete:
- """Test /chat/complete endpoint input/output format."""
-
- def test_chat_complete_valid_input_output(self, mock_mos_product, client):
- """Test chat/complete endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/chat/complete", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "message" in data
- assert "data" in data
- assert isinstance(data["data"], dict)
- assert "response" in data["data"]
- assert "references" in data["data"]
-
- # Verify MOSProduct.chat was called with correct parameters
- mock_mos_product.chat.assert_called_once()
- call_kwargs = mock_mos_product.chat.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["query"] == "test query"
-
- def test_chat_complete_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test chat/complete endpoint with missing required field."""
- request_data = {
- "query": "test query",
- }
-
- response = client.post("/product/chat/complete", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_chat_complete_response_format(self, mock_mos_product, client):
- """Test chat/complete endpoint returns correct format."""
- mock_mos_product.chat.return_value = ("test response", [{"id": "ref1"}])
-
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- }
-
- response = client.post("/product/chat/complete", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Chat completed successfully"
- assert isinstance(data["data"]["response"], str)
- assert isinstance(data["data"]["references"], list)
-
-
-class TestProductRouterChat:
- """Test /chat endpoint input/output format (SSE stream)."""
-
- def test_chat_valid_input_output(self, mock_mos_product, client):
- """Test chat endpoint with valid input returns SSE stream."""
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/chat", json=request_data)
-
- assert response.status_code == 200
- assert "text/event-stream" in response.headers["content-type"]
-
- # Verify MOSProduct.chat_with_references was called
- mock_mos_product.chat_with_references.assert_called_once()
- call_kwargs = mock_mos_product.chat_with_references.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["query"] == "test query"
-
- def test_chat_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test chat endpoint with missing required field."""
- request_data = {
- "query": "test query",
- }
-
- response = client.post("/product/chat", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
-
-class TestProductRouterSuggestions:
- """Test /suggestions endpoint input/output format."""
-
- def test_suggestions_valid_input_output(self, mock_mos_product, client):
- """Test suggestions endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "mem_cube_id": "test_cube",
- "language": "zh",
- }
-
- response = client.post("/product/suggestions", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert isinstance(data["data"], dict)
- assert "query" in data["data"]
-
- # Verify MOSProduct.get_suggestion_query was called
- mock_mos_product.get_suggestion_query.assert_called_once()
- call_kwargs = mock_mos_product.get_suggestion_query.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
-
- def test_suggestions_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test suggestions endpoint with missing required field."""
- request_data = {
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/suggestions", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_suggestions_response_format(self, mock_mos_product, client):
- """Test suggestions endpoint returns SuggestionResponse format."""
- mock_mos_product.get_suggestion_query.return_value = [
- "suggestion1",
- "suggestion2",
- "suggestion3",
- ]
-
- request_data = {
- "user_id": "test_user",
- "mem_cube_id": "test_cube",
- "language": "en",
- }
-
- response = client.post("/product/suggestions", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Suggestions retrieved successfully"
- assert isinstance(data["data"], dict)
- assert isinstance(data["data"]["query"], list)
-
-
-class TestProductRouterGetAll:
- """Test /get_all endpoint input/output format."""
-
- def test_get_all_valid_input_output(self, mock_mos_product, client):
- """Test get_all endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "memory_type": "text_mem",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert isinstance(data["data"], list)
-
- # Verify MOSProduct.get_all was called
- mock_mos_product.get_all.assert_called_once()
- call_kwargs = mock_mos_product.get_all.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["memory_type"] == "text_mem"
-
- def test_get_all_with_search_query(self, mock_mos_product, client):
- """Test get_all endpoint with search_query uses get_subgraph."""
- # Reset mock call counts
- mock_mos_product.get_all.reset_mock()
- mock_mos_product.get_subgraph.reset_mock()
-
- request_data = {
- "user_id": "test_user",
- "memory_type": "text_mem",
- "search_query": "test query",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- assert response.status_code == 200
- # Verify get_subgraph was called instead of get_all
- mock_mos_product.get_subgraph.assert_called_once()
- mock_mos_product.get_all.assert_not_called()
-
- def test_get_all_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test get_all endpoint with missing required field."""
- request_data = {
- "memory_type": "text_mem",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_get_all_response_format(self, mock_mos_product, client):
- """Test get_all endpoint returns MemoryResponse format."""
- mock_mos_product.get_all.return_value = [{"cube_id": "test_cube", "memories": []}]
-
- request_data = {
- "user_id": "test_user",
- "memory_type": "text_mem",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Memories retrieved successfully"
- assert isinstance(data["data"], list)
- assert len(data["data"]) > 0
diff --git a/tests/api/test_start_api.py b/tests/api/test_start_api.py
deleted file mode 100644
index e1ffcd74b..000000000
--- a/tests/api/test_start_api.py
+++ /dev/null
@@ -1,401 +0,0 @@
-from unittest.mock import Mock, patch
-
-import pytest
-
-from fastapi.testclient import TestClient
-
-from memos.api.start_api import app
-from memos.mem_user.user_manager import UserRole
-
-
-client = TestClient(app)
-
-# Mock data
-MOCK_MESSAGE = {"role": "user", "content": "test message"}
-MOCK_MEMORY_CREATE = {
- "messages": [MOCK_MESSAGE],
- "mem_cube_id": "test_cube",
- "user_id": "test_user",
-}
-MOCK_MEMORY_CONTENT = {
- "memory_content": "test memory content",
- "mem_cube_id": "test_cube",
- "user_id": "test_user",
-}
-MOCK_DOC_PATH = {"doc_path": "/path/to/doc", "mem_cube_id": "test_cube", "user_id": "test_user"}
-MOCK_SEARCH_REQUEST = {
- "query": "test query",
- "user_id": "test_user",
- "install_cube_ids": ["test_cube"],
-}
-MOCK_MEMCUBE_REGISTER = {
- "mem_cube_name_or_path": "test_cube_path",
- "mem_cube_id": "test_cube",
- "user_id": "test_user",
-}
-MOCK_CHAT_REQUEST = {"query": "test chat query", "user_id": "test_user"}
-MOCK_USER_CREATE = {"user_id": "test_user", "user_name": "Test User", "role": "USER"}
-MOCK_CUBE_SHARE = {"target_user_id": "target_user"}
-MOCK_CONFIG = {
- "user_id": "test_user",
- "session_id": "test_session",
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "top_k": 5,
- "chat_model": {
- "backend": "openai",
- "config": {
- "model_name_or_path": "gpt-3.5-turbo",
- "api_key": "test_key",
- "temperature": 0.7,
- "api_base": "https://api.openai.com/v1",
- },
- },
-}
-
-
-@pytest.fixture
-def mock_mos():
- """Mock MOS instance for testing."""
- with patch("memos.api.start_api.get_mos_instance") as mock_get_mos:
- # Create a mock MOS instance
- mock_instance = Mock()
-
- # Set up default return values for methods
- mock_instance.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_instance.get_all.return_value = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_instance.get.return_value = {"memory": "test memory"}
- mock_instance.chat.return_value = "test response"
- mock_instance.list_users.return_value = []
- mock_instance.get_user_info.return_value = {
- "user_id": "test_user",
- "user_name": "Test User",
- "role": "user",
- "accessible_cubes": [],
- }
- mock_instance.create_user.return_value = "test_user"
- mock_instance.share_cube_with_user.return_value = True
-
- # Configure the mock to return our mock instance
- mock_get_mos.return_value = mock_instance
-
- yield mock_instance
-
-
-def test_configure_error(mock_mos):
- """Test configuration endpoint with error."""
- with patch("memos.api.start_api.MOS_INSTANCE", None):
- response = client.post("/configure", json={})
- assert response.status_code == 422 # FastAPI validation error
-
-
-def test_create_user(mock_mos):
- """Test user creation endpoint."""
- response = client.post("/users", json=MOCK_USER_CREATE)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "User created successfully",
- "data": {"user_id": "test_user"},
- }
- mock_mos.create_user.assert_called_once_with(
- user_id="test_user", role=UserRole.USER, user_name="Test User"
- )
-
-
-def test_create_user_validation_error(mock_mos):
- """Test user creation with validation error."""
- mock_mos.create_user.side_effect = ValueError("Invalid user data")
- response = client.post("/users", json=MOCK_USER_CREATE)
- assert response.status_code == 400
- assert "Invalid user data" in response.json()["message"]
-
-
-def test_list_users(mock_mos):
- """Test list users endpoint."""
- # Set up mock to return the expected data structure
- mock_users = [
- {
- "user_id": "test_user",
- "user_name": "Test User",
- "role": "user",
- "created_at": "2023-01-01T00:00:00",
- "is_active": True,
- }
- ]
- mock_mos.list_users.return_value = mock_users
-
- response = client.get("/users")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Users retrieved successfully",
- "data": mock_users,
- }
- mock_mos.list_users.assert_called_once()
-
-
-def test_get_user_info(mock_mos):
- """Test get user info endpoint."""
- # Set up mock to return the expected data structure
- mock_user_info = {
- "user_id": "test_user",
- "user_name": "Test User",
- "role": "user",
- "created_at": "2023-01-01T00:00:00",
- "accessible_cubes": [],
- }
- mock_mos.get_user_info.return_value = mock_user_info
-
- response = client.get("/users/me")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "User info retrieved successfully",
- "data": mock_user_info,
- }
- mock_mos.get_user_info.assert_called_once()
-
-
-def test_register_mem_cube(mock_mos):
- """Test MemCube registration endpoint."""
- response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "MemCube registered successfully",
- "data": None,
- }
- mock_mos.register_mem_cube.assert_called_once_with(
- mem_cube_name_or_path="test_cube_path", mem_cube_id="test_cube", user_id="test_user"
- )
-
-
-def test_register_mem_cube_validation_error(mock_mos):
- """Test MemCube registration with validation error."""
- mock_mos.register_mem_cube.side_effect = ValueError("Invalid MemCube")
- response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER)
- assert response.status_code == 400
- assert "Invalid MemCube" in response.json()["message"]
-
-
-def test_unregister_mem_cube(mock_mos):
- """Test MemCube unregistration endpoint."""
- response = client.delete("/mem_cubes/test_cube?user_id=test_user")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "MemCube unregistered successfully",
- "data": None,
- }
- mock_mos.unregister_mem_cube.assert_called_once_with(
- mem_cube_id="test_cube", user_id="test_user"
- )
-
-
-def test_unregister_nonexistent_mem_cube(mock_mos):
- """Test unregistering a non-existent MemCube."""
- mock_mos.unregister_mem_cube.side_effect = ValueError("MemCube not found")
- response = client.delete("/mem_cubes/nonexistent_cube")
- assert response.status_code == 400
- assert "MemCube not found" in response.json()["message"]
-
-
-def test_share_cube(mock_mos):
- """Test cube sharing endpoint."""
- response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE)
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Cube shared successfully", "data": None}
- mock_mos.share_cube_with_user.assert_called_once_with("test_cube", "target_user")
-
-
-def test_share_cube_failure(mock_mos):
- """Test cube sharing failure."""
- mock_mos.share_cube_with_user.return_value = False
- response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE)
- assert response.status_code == 400
- assert "Failed to share cube" in response.json()["message"]
-
-
-@pytest.mark.parametrize(
- "memory_create,expected_calls",
- [
- (MOCK_MEMORY_CREATE, {"messages": [MOCK_MESSAGE]}),
- (MOCK_MEMORY_CONTENT, {"memory_content": "test memory content"}),
- (MOCK_DOC_PATH, {"doc_path": "/path/to/doc"}),
- ],
-)
-def test_add_memory(mock_mos, memory_create, expected_calls):
- """Test adding memories with different types of content."""
- response = client.post("/memories", json=memory_create)
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Memories added successfully", "data": None}
- mock_mos.add.assert_called_once()
-
-
-def test_add_memory_validation_error(mock_mos):
- """Test adding memory with validation error."""
- response = client.post("/memories", json={})
- assert response.status_code == 400
- assert "must be provided" in response.json()["message"]
-
-
-def test_get_all_memories(mock_mos):
- """Test get all memories endpoint."""
- mock_results = {
- "text_mem": [{"cube_id": "test_cube", "memories": []}],
- "act_mem": [],
- "para_mem": [],
- }
- mock_mos.get_all.return_value = mock_results
-
- response = client.get("/memories")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Memories retrieved successfully",
- "data": mock_results,
- }
- mock_mos.get_all.assert_called_once_with(mem_cube_id=None, user_id=None)
-
-
-def test_get_memory(mock_mos):
- """Test get specific memory endpoint."""
- mock_memory = {"memory": "test memory content"}
- mock_mos.get.return_value = mock_memory
-
- response = client.get("/memories/test_cube/test_memory")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Memory retrieved successfully",
- "data": mock_memory,
- }
- mock_mos.get.assert_called_once_with(
- mem_cube_id="test_cube", memory_id="test_memory", user_id=None
- )
-
-
-def test_get_nonexistent_memory(mock_mos):
- """Test getting a non-existent memory."""
- mock_mos.get.side_effect = ValueError("Memory not found")
- response = client.get("/memories/test_cube/nonexistent_memory")
- assert response.status_code == 400
- assert "Memory not found" in response.json()["message"]
-
-
-def test_search_memories(mock_mos):
- """Test search memories endpoint."""
- # Mock the search method to return a proper result structure
- mock_results = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_mos.search.return_value = mock_results
-
- # Ensure the search request has all required fields
- search_request = {
- "query": "test query",
- "user_id": "test_user",
- "install_cube_ids": ["test_cube"],
- }
-
- response = client.post("/search", json=search_request)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Search completed successfully",
- "data": mock_results,
- }
- mock_mos.search.assert_called_once_with(
- query="test query", user_id="test_user", install_cube_ids=["test_cube"]
- )
-
-
-def test_update_memory(mock_mos):
- """Test updating a memory endpoint."""
- update_data = {"content": "updated content"}
- response = client.put("/memories/test_cube/test_memory?user_id=test_user", json=update_data)
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Memory updated successfully", "data": None}
- mock_mos.update.assert_called_once_with(
- mem_cube_id="test_cube",
- memory_id="test_memory",
- text_memory_item=update_data,
- user_id="test_user",
- )
-
-
-def test_update_nonexistent_memory(mock_mos):
- """Test updating a non-existent memory."""
- mock_mos.update.side_effect = ValueError("Memory not found")
- response = client.put("/memories/test_cube/nonexistent_memory", json={})
- assert response.status_code == 400
- assert "Memory not found" in response.json()["message"]
-
-
-def test_delete_memory(mock_mos):
- """Test deleting a memory endpoint."""
- response = client.delete("/memories/test_cube/test_memory?user_id=test_user")
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Memory deleted successfully", "data": None}
- mock_mos.delete.assert_called_once_with(
- mem_cube_id="test_cube", memory_id="test_memory", user_id="test_user"
- )
-
-
-def test_delete_nonexistent_memory(mock_mos):
- """Test deleting a non-existent memory."""
- mock_mos.delete.side_effect = ValueError("Memory not found")
- response = client.delete("/memories/test_cube/nonexistent_memory")
- assert response.status_code == 400
- assert "Memory not found" in response.json()["message"]
-
-
-def test_delete_all_memories(mock_mos):
- """Test deleting all memories endpoint."""
- response = client.delete("/memories/test_cube?user_id=test_user")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "All memories deleted successfully",
- "data": None,
- }
- mock_mos.delete_all.assert_called_once_with(mem_cube_id="test_cube", user_id="test_user")
-
-
-def test_delete_all_nonexistent_memories(mock_mos):
- """Test deleting all memories from non-existent MemCube."""
- mock_mos.delete_all.side_effect = ValueError("MemCube not found")
- response = client.delete("/memories/nonexistent_cube")
- assert response.status_code == 400
- assert "MemCube not found" in response.json()["message"]
-
-
-def test_chat(mock_mos):
- """Test chat endpoint."""
- response = client.post("/chat", json=MOCK_CHAT_REQUEST)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Chat response generated",
- "data": "test response",
- }
- mock_mos.chat.assert_called_once_with(query="test chat query", user_id="test_user")
-
-
-def test_chat_without_user_id(mock_mos):
- """Test chat endpoint without user_id."""
- chat_request = {"query": "test chat query"}
- response = client.post("/chat", json=chat_request)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Chat response generated",
- "data": "test response",
- }
- mock_mos.chat.assert_called_once_with(query="test chat query", user_id=None)
-
-
-def test_home_redirect():
- """Test home endpoint redirects to docs."""
- response = client.get("/", follow_redirects=False)
- assert response.status_code == 307
- assert response.headers["location"] == "/docs"
diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py
index f3d4549b5..d1507e389 100644
--- a/tests/configs/test_llm.py
+++ b/tests/configs/test_llm.py
@@ -2,6 +2,7 @@
BaseLLMConfig,
HFLLMConfig,
LLMConfigFactory,
+ MinimaxLLMConfig,
OllamaLLMConfig,
OpenAILLMConfig,
)
@@ -145,6 +146,42 @@ def test_hf_llm_config():
check_config_instantiation_invalid(HFLLMConfig)
+def test_minimax_llm_config():
+ check_config_base_class(
+ MinimaxLLMConfig,
+ required_fields=["model_name_or_path", "api_key"],
+ optional_fields=[
+ "temperature",
+ "max_tokens",
+ "top_p",
+ "top_k",
+ "api_base",
+ "remove_think_prefix",
+ "extra_body",
+ "default_headers",
+ "backup_client",
+ "backup_api_key",
+ "backup_api_base",
+ "backup_model_name_or_path",
+ "backup_headers",
+ ],
+ )
+
+ check_config_instantiation_valid(
+ MinimaxLLMConfig,
+ {
+ "model_name_or_path": "MiniMax-M2.7",
+ "api_key": "test-key",
+ "api_base": "https://api.minimax.io/v1",
+ "temperature": 0.7,
+ "max_tokens": 1024,
+ "top_p": 0.9,
+ },
+ )
+
+ check_config_instantiation_invalid(MinimaxLLMConfig)
+
+
def test_llm_config_factory():
check_config_factory_class(
LLMConfigFactory,
diff --git a/tests/graph_dbs/graph_dbs.py b/tests/graph_dbs/graph_dbs.py
index 2cc35a0ad..c834f65a9 100644
--- a/tests/graph_dbs/graph_dbs.py
+++ b/tests/graph_dbs/graph_dbs.py
@@ -105,3 +105,26 @@ def test_get_memory_count(graph_db):
session_mock.run.return_value.single.return_value = {"count": 42}
count = graph_db.get_memory_count("WorkingMemory")
assert count == 42
+
+
+def test_add_node_sanitizes_nested_metadata(graph_db):
+ session_mock = graph_db.driver.session.return_value.__enter__.return_value
+ node_id = str(uuid.uuid4())
+ memory = "skill memory"
+ metadata = {
+ "memory_type": "SkillMemory",
+ "embedding": [0.1, 0.2, 0.3],
+ "tags": ["skill"],
+ "scripts": {"run.py": "print(1)"},
+ "others": {"README.md": "# demo"},
+ "info": {"nested": {"x": 1}, "arr_obj": [{"a": 1}]},
+ }
+
+ graph_db.add_node(node_id, memory, metadata)
+
+ _, kwargs = session_mock.run.call_args
+ sanitized = kwargs["metadata"]
+ assert isinstance(sanitized["scripts"], str)
+ assert isinstance(sanitized["others"], str)
+ assert isinstance(sanitized["nested"], str)
+ assert sanitized["arr_obj"] == ['{"a": 1}']
diff --git a/tests/graph_dbs/test_neo4j_vector_search.py b/tests/graph_dbs/test_neo4j_vector_search.py
new file mode 100644
index 000000000..3ed0b0587
--- /dev/null
+++ b/tests/graph_dbs/test_neo4j_vector_search.py
@@ -0,0 +1,425 @@
+"""
+Tests for Neo4j vector search pre-filtering and related regressions.
+
+- Unit tests: verify query structure (pre-filter vs ANN paths) using mocks
+- Integration tests: verify real search behavior with multi-user data (requires Neo4j 5.18+)
+
+The pre-filter approach (Neo4j 5.18+):
+ When WHERE filters are present (scope, status, user_name, etc.), the query uses
+ MATCH + WHERE to narrow candidates first, then vector.similarity.cosine()
+ computes similarity only on the filtered set. This avoids the post-filter
+    problem entirely — no nodes are lost due to global top-k truncation.
+
+ When no filters are present, the ANN vector index (db.index.vector.queryNodes)
+ is used for maximum efficiency.
+"""
+
+import math
+import os
+import uuid
+
+from datetime import datetime, timezone
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from memos.configs.graph_db import Neo4jGraphDBConfig
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Fixtures for unit tests (mocked Neo4j driver)
+# ──────────────────────────────────────────────────────────────────────────────
+
+
+@pytest.fixture
+def shared_db_config():
+ """Shared-database multi-tenant config (use_multi_db=False)."""
+ return Neo4jGraphDBConfig(
+ uri="bolt://localhost:7687",
+ user="neo4j",
+ password="test",
+ db_name="test_db",
+ auto_create=False,
+ use_multi_db=False,
+ user_name="default_user",
+ embedding_dimension=3,
+ )
+
+
+@pytest.fixture
+def multi_db_config():
+ """Multi-database config โ no user_name filter in queries."""
+ return Neo4jGraphDBConfig(
+ uri="bolt://localhost:7687",
+ user="neo4j",
+ password="test",
+ db_name="test_db",
+ auto_create=False,
+ use_multi_db=True,
+ embedding_dimension=3,
+ )
+
+
+@pytest.fixture
+def shared_neo4j_db(shared_db_config):
+ with patch("neo4j.GraphDatabase") as mock_gd:
+ mock_driver = MagicMock()
+ mock_gd.driver.return_value = mock_driver
+ from memos.graph_dbs.neo4j import Neo4jGraphDB
+
+ db = Neo4jGraphDB(shared_db_config)
+ db.driver = mock_driver
+ yield db
+
+
+@pytest.fixture
+def multi_neo4j_db(multi_db_config):
+ with patch("neo4j.GraphDatabase") as mock_gd:
+ mock_driver = MagicMock()
+ mock_gd.driver.return_value = mock_driver
+ from memos.graph_dbs.neo4j import Neo4jGraphDB
+
+ db = Neo4jGraphDB(multi_db_config)
+ db.driver = mock_driver
+ yield db
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Unit tests: pre-filter vs ANN query paths
+# ──────────────────────────────────────────────────────────────────────────────
+
+
+class TestVectorSearchPreFilter:
+ """Verify pre-filter path uses MATCH + vector.similarity.cosine()
+ and ANN path uses db.index.vector.queryNodes."""
+
+ def test_prefilter_with_scope(self, shared_neo4j_db):
+ """With scope filter, query should use MATCH + cosine similarity, not queryNodes."""
+ session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ shared_neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ scope="LongTermMemory",
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "MATCH (node:Memory)" in query
+ assert "vector.similarity.cosine(node.embedding, $embedding)" in query
+ assert "queryNodes" not in query
+
+ def test_prefilter_with_all_filters(self, shared_neo4j_db):
+ """With scope + status + user_name, all filters appear in WHERE before similarity."""
+ session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ shared_neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=10,
+ scope="LongTermMemory",
+ status="activated",
+ user_name="some_user",
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "MATCH (node:Memory)" in query
+ assert "node.memory_type = $scope" in query
+ assert "node.status = $status" in query
+ assert "node.user_name = $user_name" in query
+ assert "vector.similarity.cosine" in query
+
+ def test_prefilter_includes_embedding_not_null(self, shared_neo4j_db):
+ """Pre-filter query should exclude nodes without embeddings."""
+ session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ shared_neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ scope="LongTermMemory",
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "node.embedding IS NOT NULL" in query
+
+ def test_prefilter_has_order_by_and_limit(self, shared_neo4j_db):
+ """Pre-filter results should be ordered by score and limited."""
+ session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ shared_neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ scope="LongTermMemory",
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "ORDER BY score DESC" in query
+ assert "LIMIT $top_k" in query
+ params = session_mock.run.call_args[0][1]
+ assert params["top_k"] == 5
+
+ def test_ann_path_without_filters(self, multi_neo4j_db):
+ """Without any filter, query should use queryNodes ANN index."""
+ session_mock = multi_neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ multi_neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "queryNodes" in query
+ assert "$top_k" in query
+ assert "MATCH (node:Memory)" not in query
+ params = session_mock.run.call_args[0][1]
+ assert params["top_k"] == 5
+
+ def test_ann_path_no_redundant_params(self, multi_neo4j_db):
+ """ANN path should only have embedding and top_k, nothing extra."""
+ session_mock = multi_neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ multi_neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ )
+
+ params = session_mock.run.call_args[0][1]
+ assert set(params.keys()) == {"embedding", "top_k"}
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Unit tests: sources KeyError regression
+# ──────────────────────────────────────────────────────────────────────────────
+
+
+class TestSourcesKeyErrorRegression:
+ """Verify that missing/None 'sources' key doesn't cause KeyError."""
+
+ def test_add_node_without_sources_key(self, shared_neo4j_db):
+ session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
+
+ shared_neo4j_db.add_node(
+ id="test-node-1",
+ memory="test content",
+ metadata={
+ "memory_type": "WorkingMemory",
+ "embedding": [0.1, 0.2, 0.3],
+ "created_at": datetime.now(timezone.utc).isoformat(),
+ "updated_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+ calls = session_mock.run.call_args_list
+ assert any("MERGE (n:Memory" in str(call) for call in calls)
+
+ def test_add_node_with_empty_sources(self, shared_neo4j_db):
+ _session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value
+
+ shared_neo4j_db.add_node(
+ id="test-node-2",
+ memory="test content",
+ metadata={
+ "memory_type": "WorkingMemory",
+ "embedding": [0.1, 0.2, 0.3],
+ "sources": [],
+ "created_at": datetime.now(timezone.utc).isoformat(),
+ "updated_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+ def test_parse_node_without_sources_key(self, shared_neo4j_db):
+ result = shared_neo4j_db._parse_node(
+ {
+ "id": "node-1",
+ "memory": "hello",
+ "memory_type": "WorkingMemory",
+ "created_at": datetime.now(timezone.utc),
+ "updated_at": datetime.now(timezone.utc),
+ }
+ )
+ assert result["id"] == "node-1"
+ assert result["memory"] == "hello"
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Integration tests (require a running Neo4j 5.18+ with vector index)
+#
+# Activate by setting environment variables:
+# NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD
+#
+# Run:
+# pytest tests/graph_dbs/test_neo4j_vector_search.py -k Integration -v
+# ──────────────────────────────────────────────────────────────────────────────
+
+
+def _neo4j_package_available():
+ try:
+ import neo4j # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
+_neo4j_configured = _neo4j_package_available() and all(
+ os.getenv(k) for k in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD")
+)
+_TEST_RUN_ID = uuid.uuid4().hex[:8]
+_TARGET_USER = f"__test_target_{_TEST_RUN_ID}"
+_OTHER_USER_PREFIX = f"__test_other_{_TEST_RUN_ID}"
+
+
+def _make_unit_vector(
+ dim: int, dominant_axis: int, secondary_axis: int | None = None
+) -> list[float]:
+ """
+ Create a unit vector concentrated on one axis, optionally blended with a second.
+
+ Used to control cosine similarity in tests:
+    - Two vectors on the same axis → cos_sim ≈ 1.0
+    - Orthogonal axes → cos_sim ≈ 0.0
+    - Blended → cos_sim ≈ 0.707
+ """
+ vec = [0.0] * dim
+ vec[dominant_axis % dim] = 1.0
+ if secondary_axis is not None:
+ vec[secondary_axis % dim] = 1.0
+ norm = math.sqrt(sum(x * x for x in vec))
+ return [x / norm for x in vec]
+
+
+@pytest.fixture(scope="module")
+def integration_config():
+ if not _neo4j_configured:
+ pytest.skip("Neo4j not configured (need NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)")
+ return Neo4jGraphDBConfig(
+ uri=os.getenv("NEO4J_URI"),
+ user=os.getenv("NEO4J_USER"),
+ password=os.getenv("NEO4J_PASSWORD"),
+ db_name=os.getenv("NEO4J_DB_NAME", "neo4j"),
+ auto_create=False,
+ use_multi_db=False,
+ user_name=f"__test_default_{_TEST_RUN_ID}",
+ embedding_dimension=int(os.getenv("EMBEDDING_DIMENSION", "1536")),
+ )
+
+
+@pytest.fixture(scope="module")
+def integration_db(integration_config):
+ from memos.graph_dbs.neo4j import Neo4jGraphDB
+
+ return Neo4jGraphDB(integration_config)
+
+
+@pytest.mark.skipif(not _neo4j_configured, reason="Neo4j not configured")
+class TestNeo4jPreFilterIntegration:
+ """
+ Integration test: pre-filtered vector search in a multi-user shared database.
+
+ Uses vector.similarity.cosine() with MATCH + WHERE to pre-filter by user,
+ guaranteeing that target user's nodes are always considered regardless of
+ how many other users' nodes exist in the database.
+ """
+
+ @pytest.fixture(scope="class", autouse=True)
+ def seed_and_cleanup(self, integration_db, integration_config):
+ """
+ Seed multi-user test data, then clean up.
+
+        - 50 "other" user nodes: embeddings along axis 0 → cos_sim ≈ 1.0 with query
+        - 3 "target" user nodes: embeddings blended axis 0+1 → cos_sim ≈ 0.707 with query
+
+ With pre-filtering, only the target user's 3 nodes are candidates for
+ similarity computation, so all 3 are always returned.
+ """
+ dim = integration_config.embedding_dimension
+ now = datetime.now(timezone.utc).isoformat()
+
+ for i in range(50):
+ other_user = f"{_OTHER_USER_PREFIX}_{i % 10}"
+ integration_db.add_node(
+ id=f"__test_other_{_TEST_RUN_ID}_{i}",
+ memory=f"Other user memory {i}",
+ metadata={
+ "memory_type": "LongTermMemory",
+ "status": "activated",
+ "embedding": _make_unit_vector(dim, dominant_axis=0),
+ "created_at": now,
+ "updated_at": now,
+ },
+ user_name=other_user,
+ )
+
+ for i in range(3):
+ integration_db.add_node(
+ id=f"__test_target_{_TEST_RUN_ID}_{i}",
+ memory=f"Target user memory {i}",
+ metadata={
+ "memory_type": "LongTermMemory",
+ "status": "activated",
+ "embedding": _make_unit_vector(dim, dominant_axis=0, secondary_axis=1),
+ "created_at": now,
+ "updated_at": now,
+ },
+ user_name=_TARGET_USER,
+ )
+
+ yield
+
+ integration_db.clear(user_name=_TARGET_USER)
+ for i in range(10):
+ integration_db.clear(user_name=f"{_OTHER_USER_PREFIX}_{i}")
+
+ def test_search_returns_all_target_user_results(self, integration_db, integration_config):
+ """Pre-filtering guarantees all target user nodes are found."""
+ dim = integration_config.embedding_dimension
+ query_vector = _make_unit_vector(dim, dominant_axis=0)
+
+ results = integration_db.search_by_embedding(
+ vector=query_vector,
+ top_k=3,
+ scope="LongTermMemory",
+ status="activated",
+ user_name=_TARGET_USER,
+ )
+
+ assert len(results) == 3, (
+ f"Pre-filter should return all 3 target user nodes, got {len(results)}. "
+ "This indicates pre-filtering is not working correctly."
+ )
+
+ def test_all_returned_ids_belong_to_target_user(self, integration_db, integration_config):
+ dim = integration_config.embedding_dimension
+ query_vector = _make_unit_vector(dim, dominant_axis=0)
+
+ results = integration_db.search_by_embedding(
+ vector=query_vector,
+ top_k=3,
+ scope="LongTermMemory",
+ status="activated",
+ user_name=_TARGET_USER,
+ )
+
+ for r in results:
+ assert r["id"].startswith(f"__test_target_{_TEST_RUN_ID}_"), (
+ f"Result {r['id']} does not belong to the target user"
+ )
+
+ def test_scores_are_positive(self, integration_db, integration_config):
+ dim = integration_config.embedding_dimension
+ query_vector = _make_unit_vector(dim, dominant_axis=0)
+
+ results = integration_db.search_by_embedding(
+ vector=query_vector,
+ top_k=3,
+ scope="LongTermMemory",
+ status="activated",
+ user_name=_TARGET_USER,
+ )
+
+ for r in results:
+ assert r["score"] > 0, f"Score should be positive, got {r['score']}"
diff --git a/tests/llms/test_minimax.py b/tests/llms/test_minimax.py
new file mode 100644
index 000000000..d984adcef
--- /dev/null
+++ b/tests/llms/test_minimax.py
@@ -0,0 +1,114 @@
+import unittest
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+from memos.configs.llm import MinimaxLLMConfig
+from memos.llms.minimax import MinimaxLLM
+
+
+class TestMinimaxLLM(unittest.TestCase):
+ def test_minimax_llm_generate_with_and_without_think_prefix(self):
+ """Test MinimaxLLM generate method with and without tag removal."""
+
+ # Simulated full content including tag
+ full_content = "Hello from MiniMax!"
+ reasoning_content = "Thinking in progress..."
+
+ # Mock response object
+ mock_response = MagicMock()
+ mock_response.model_dump_json.return_value = '{"mock": "true"}'
+ mock_response.choices[0].message.content = full_content
+ mock_response.choices[0].message.reasoning_content = reasoning_content
+
+ # Config with think prefix preserved
+ config_with_think = MinimaxLLMConfig.model_validate(
+ {
+ "model_name_or_path": "MiniMax-M2.7",
+ "temperature": 0.7,
+ "max_tokens": 512,
+ "top_p": 0.9,
+ "api_key": "sk-test",
+ "api_base": "https://api.minimax.io/v1",
+ "remove_think_prefix": False,
+ }
+ )
+ llm_with_think = MinimaxLLM(config_with_think)
+ llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response)
+
+ output_with_think = llm_with_think.generate([{"role": "user", "content": "Hello"}])
+ self.assertEqual(output_with_think, f"{reasoning_content}{full_content}")
+
+ # Config with think tag removed
+ config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True})
+ llm_without_think = MinimaxLLM(config_without_think)
+ llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response)
+
+ output_without_think = llm_without_think.generate([{"role": "user", "content": "Hello"}])
+ self.assertEqual(output_without_think, full_content)
+
+ def test_minimax_llm_generate_stream(self):
+ """Test MinimaxLLM generate_stream with content chunks."""
+
+ def make_chunk(delta_dict):
+ # Create a simulated stream chunk with delta fields
+ delta = SimpleNamespace(**delta_dict)
+ choice = SimpleNamespace(delta=delta)
+ return SimpleNamespace(choices=[choice])
+
+ # Simulate chunks: content only (MiniMax standard response)
+ mock_stream_chunks = [
+ make_chunk({"content": "Hello"}),
+ make_chunk({"content": ", "}),
+ make_chunk({"content": "MiniMax!"}),
+ ]
+
+ mock_chat_completions_create = MagicMock(return_value=iter(mock_stream_chunks))
+
+ config = MinimaxLLMConfig.model_validate(
+ {
+ "model_name_or_path": "MiniMax-M2.7",
+ "temperature": 0.7,
+ "max_tokens": 512,
+ "top_p": 0.9,
+ "api_key": "sk-test",
+ "api_base": "https://api.minimax.io/v1",
+ "remove_think_prefix": False,
+ }
+ )
+ llm = MinimaxLLM(config)
+ llm.client.chat.completions.create = mock_chat_completions_create
+
+ messages = [{"role": "user", "content": "Say hello"}]
+ streamed = list(llm.generate_stream(messages))
+ full_output = "".join(streamed)
+
+ self.assertEqual(full_output, "Hello, MiniMax!")
+
+ def test_minimax_llm_config_defaults(self):
+ """Test MinimaxLLMConfig default values."""
+ config = MinimaxLLMConfig.model_validate(
+ {
+ "model_name_or_path": "MiniMax-M2.7",
+ "api_key": "sk-test",
+ }
+ )
+ self.assertEqual(config.api_base, "https://api.minimax.io/v1")
+ self.assertEqual(config.temperature, 0.7)
+ self.assertEqual(config.max_tokens, 8192)
+
+ def test_minimax_llm_config_custom_values(self):
+ """Test MinimaxLLMConfig with custom values."""
+ config = MinimaxLLMConfig.model_validate(
+ {
+ "model_name_or_path": "MiniMax-M2.7-highspeed",
+ "api_key": "sk-test",
+ "api_base": "https://custom.api.minimax.io/v1",
+ "temperature": 0.5,
+ "max_tokens": 2048,
+ }
+ )
+ self.assertEqual(config.model_name_or_path, "MiniMax-M2.7-highspeed")
+ self.assertEqual(config.api_base, "https://custom.api.minimax.io/v1")
+ self.assertEqual(config.temperature, 0.5)
+ self.assertEqual(config.max_tokens, 2048)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9750af121..a1e423e4f 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -16,13 +16,13 @@
class TestExportOpenAPI:
"""Test the export_openapi function."""
- @patch("memos.api.start_api.app")
+ @patch("memos.cli.get_openapi_app")
@patch("builtins.open", new_callable=mock_open)
@patch("os.makedirs")
def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app):
"""Test successful OpenAPI export."""
mock_openapi_data = {"openapi": "3.0.0", "info": {"title": "Test API"}}
- mock_app.openapi.return_value = mock_openapi_data
+ mock_app.return_value.openapi.return_value = mock_openapi_data
result = export_openapi("/test/path/openapi.json")
@@ -30,11 +30,11 @@ def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app):
mock_makedirs.assert_called_once_with("/test/path", exist_ok=True)
mock_file.assert_called_once_with("/test/path/openapi.json", "w")
- @patch("memos.api.start_api.app")
+ @patch("memos.cli.get_openapi_app")
@patch("builtins.open", side_effect=OSError("Permission denied"))
def test_export_openapi_error(self, mock_file, mock_app):
"""Test OpenAPI export when file writing fails."""
- mock_app.openapi.return_value = {"test": "data"}
+ mock_app.return_value.openapi.return_value = {"test": "data"}
with pytest.raises(IOError):
export_openapi("/invalid/path/openapi.json")