diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..4a7682b --- /dev/null +++ b/.env.example @@ -0,0 +1,6 @@ +# 说明:.env 已配置 Git 忽略,仅本地生效;.env.example 为公开模板,不会被忽略 +# 使用方式:复制本文件并重命名为 .env,在 .env 中的 API_KEY 键填入真实 API Key +# 重要警告:严禁在 .env.example 中写入真实密钥并提交至远程仓库,防止 Token 泄露 +# 补充:仓库虽然开启了密钥泄露通知,但仅作兜底防护,请勿依赖该机制 + +API_KEY=sh-... diff --git a/.gitignore b/.gitignore index cada1bc..4c5b2ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target -/.trae \ No newline at end of file +/.trae +.env \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 31c3cb7..cc5cdae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,10 +7,12 @@ name = "OmegaCode" version = "0.1.0-alpha" dependencies = [ "anyhow", + "bytes", "chrono", "clap", "colored", "crossterm 0.29.0", + "dotenv", "flate2", "flexi_logger", "futures", @@ -787,6 +789,12 @@ dependencies = [ "litrs", ] +[[package]] +name = "dotenv" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" + [[package]] name = "dunce" version = "1.0.5" @@ -2524,6 +2532,8 @@ dependencies = [ "rustls", "rustls-pki-types", "rustls-platform-verifier", + "serde", + "serde_json", "sync_wrapper", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index e4c12a2..3eab6d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ anyhow = "1.0.102" tokio = { version = "1.52.3", features = ["full"] } colored = "3.1.1" indicatif = "0.18.4" -reqwest = { version = "0.13.3", features = ["stream"] } +reqwest = { version = "0.13.3", features = ["stream", "json"] } tokio-stream = "0.1.18" futures = "0.3.32" futures-util = "0.3.32" @@ -38,4 +38,6 @@ zip = "8.6.0" flate2 = "1.1.9" tar = "0.4.46" thiserror = "2.0.18" -hex = "0.4.3" \ No newline at end of file +hex = "0.4.3" +dotenv = "0.15.0" +bytes = "1.11.1" \ No newline at end of file diff --git a/src/core/chat/client.rs b/src/core/chat/client.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/core/chat/mod.rs b/src/core/chat/mod.rs new file mode 100644 index 0000000..0278328 --- /dev/null +++ b/src/core/chat/mod.rs @@ -0,0 +1,2 @@ +mod client; +mod provider; \ No newline at end of file diff --git a/src/core/chat/provider/anthropic/mod.rs b/src/core/chat/provider/anthropic/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/core/chat/provider/common/mod.rs b/src/core/chat/provider/common/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/core/chat/provider/deepseek/chat.rs b/src/core/chat/provider/deepseek/chat.rs new file mode 100644 index 0000000..861260b --- /dev/null +++ b/src/core/chat/provider/deepseek/chat.rs @@ -0,0 +1,117 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use super::{enums::*, message::ChatMessage, tool::*}; + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct StreamOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub include_usage: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct ResponseFormat { + pub r#type: ResponseFormatType, +} + +/// 聊天补全请求 +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct CreateChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +/// 聊天补全响应 +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, + pub created: i64, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + pub choices: Vec, + pub usage: Usage, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_filter_results: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Choice { + pub index: i32, + pub message: ChatMessage, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Logprobs { + pub content: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LogprobContent { + pub token: String, + pub logprob: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub bytes: Option>, + pub top_logprobs: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TopLogprob { + pub token: String, + pub logprob: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub bytes: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PromptTokensDetails { + pub cached_tokens: u32, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CompletionTokensDetails { + pub reasoning_tokens: u32, +} \ No newline at end of file diff --git a/src/core/chat/provider/deepseek/enums.rs b/src/core/chat/provider/deepseek/enums.rs new file mode 100644 index 0000000..e241678 --- /dev/null +++ b/src/core/chat/provider/deepseek/enums.rs @@ -0,0 +1,38 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum Role { + System, + User, + Assistant, + Tool, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ToolType { + Function, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ResponseFormatType { + Text, + JsonObject, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + ToolCalls, + ContentFilter, +} + +impl Default for ResponseFormatType { + fn default() -> Self { + Self::Text + } +} \ No newline at end of file diff --git a/src/core/chat/provider/deepseek/message.rs b/src/core/chat/provider/deepseek/message.rs new file mode 100644 index 0000000..d50b93a --- /dev/null +++ b/src/core/chat/provider/deepseek/message.rs @@ -0,0 +1,47 @@ +use serde::{Deserialize, Serialize}; +use super::enums::Role; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ContentPart { + pub r#type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ImageUrl { + pub url: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: Role, + pub content: MessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +impl Default for ChatMessage { + fn default() -> Self { + Self { + role: Role::User, + content: MessageContent::Text(String::new()), + name: None, + tool_call_id: None, + tool_calls: None, + } + } +} \ No newline at end of file diff --git a/src/core/chat/provider/deepseek/mod.rs b/src/core/chat/provider/deepseek/mod.rs new file mode 100644 index 0000000..46abd49 --- /dev/null +++ b/src/core/chat/provider/deepseek/mod.rs @@ -0,0 +1,6 @@ +mod chat; +mod model; +mod enums; +mod tool; +mod message; +mod test; \ No newline at end of file diff --git a/src/core/chat/provider/deepseek/model.rs b/src/core/chat/provider/deepseek/model.rs new file mode 100644 index 0000000..7cee7ef --- /dev/null +++ b/src/core/chat/provider/deepseek/model.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +/// GET /models 请求(无参数,空结构体) +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct ListModelsRequest {} + +/// 模型列表响应 +#[derive(Debug, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub object: String, + pub data: Vec, +} + +/// 单个模型信息 +#[derive(Debug, Serialize, Deserialize)] +pub struct ModelInfo { + pub id: String, + pub object: String, + pub owned_by: String, +} \ No newline at end of file diff --git a/src/core/chat/provider/deepseek/test.rs b/src/core/chat/provider/deepseek/test.rs new file mode 100644 index 0000000..3456c68 --- /dev/null +++ b/src/core/chat/provider/deepseek/test.rs @@ -0,0 +1,281 @@ +// 修复:删除重复导入 +use crate::core::chat::provider::deepseek::{model::*, enums::*, message::*, tool::*, chat::*}; +use reqwest::Client; +use std::env; +use dotenv::dotenv; +use serde_json; + +const API_BASE: &str = "https://api.deepseek.com"; + +fn main() {} + +#[cfg(test)] +mod tests { + use super::*; + use futures_util::StreamExt; + + /// 获取 API Key(加载 .env 文件) + fn get_api_key() -> String { + dotenv().ok(); + env::var("API_KEY").expect("❌ 请在 .env 文件中设置 API_KEY 环境变量") + } + + // 1. 测试:获取模型列表 ✅ 修复:移除 /v1 前缀 + #[tokio::test] + async fn test_list_models() -> anyhow::Result<()> { + let client = Client::new(); + let auth = format!("Bearer {}", get_api_key()); + + let resp: ListModelsResponse = client + .get(format!("{API_BASE}/models")) // 修复点1 + .header("Authorization", &auth) + .send() + .await? + .json() + .await?; + + assert_eq!(resp.object, "list"); + assert!(!resp.data.is_empty()); + + println!("\n====================================="); + println!("📋 模型列表测试结果"); + println!("====================================="); + println!("{}", serde_json::to_string_pretty(&resp)?); + println!("✅ 模型列表测试通过,共 {} 个模型", resp.data.len()); + + Ok(()) + } + + // 2. 测试:普通对话 ✅ 修复:MessageContent 取值 + #[tokio::test] + async fn test_chat_normal() -> anyhow::Result<()> { + let client = Client::new(); + let auth = format!("Bearer {}", get_api_key()); + + let req = CreateChatCompletionRequest { + model: "deepseek-v4-flash".into(), + messages: vec![ChatMessage { + role: Role::User, + content: MessageContent::Text("你好".into()), + ..Default::default() + }], + max_tokens: Some(100), + ..Default::default() + }; + + let resp: ChatCompletionResponse = client + .post(format!("{API_BASE}/v1/chat/completions")) + .header("Authorization", &auth) + .json(&req) + .send() + .await? + .json() + .await?; + + assert!(!resp.choices.is_empty()); + assert_eq!(resp.model, "deepseek-v4-flash"); + + // 修复点2:直接获取 MessageContent 文本 + let reply = match &resp.choices[0].message.content { + MessageContent::Text(text) => text, + _ => "非文本消息", + }; + + println!("\n====================================="); + println!("💬 普通对话测试结果"); + println!("====================================="); + println!("AI 回复:{}", reply); + println!("完整响应:{}", serde_json::to_string_pretty(&resp)?); + println!("✅ 普通对话测试通过"); + + Ok(()) + } + + // 3. 测试:深度思考模型 ✅ 修复:删除不存在的 reasoning_content 字段 + #[tokio::test] + async fn test_chat_reasoner() -> anyhow::Result<()> { + let client = Client::new(); + let auth = format!("Bearer {}", get_api_key()); + + let req = CreateChatCompletionRequest { + model: "deepseek-reasoner".into(), + messages: vec![ChatMessage { + role: Role::User, + content: MessageContent::Text("1024*1024等于多少?".into()), + ..Default::default() + }], + ..Default::default() + }; + + let resp: ChatCompletionResponse = client + .post(format!("{API_BASE}/v1/chat/completions")) + .header("Authorization", &auth) + .json(&req) + .send() + .await? + .json() + .await?; + + assert!(!resp.choices.is_empty()); + + // 修复点3:移除无此字段的打印 + let reply = match &resp.choices[0].message.content { + MessageContent::Text(text) => text, + _ => "非文本消息", + }; + + println!("\n====================================="); + println!("🤯 深度思考模型测试结果"); + println!("====================================="); + println!("最终回复:{}", reply); + println!("完整响应:{}", serde_json::to_string_pretty(&resp)?); + println!("✅ 思考模型测试通过"); + + Ok(()) + } + + // 4. 测试:JSON 格式输出 ✅ 修复:MessageContent 取值 + #[tokio::test] + async fn test_chat_json_format() -> anyhow::Result<()> { + let client = Client::new(); + let auth = format!("Bearer {}", get_api_key()); + + let req = CreateChatCompletionRequest { + model: "deepseek-v4-flash".into(), + messages: vec![ChatMessage { + role: Role::User, + content: MessageContent::Text("返回一个包含name和age的JSON".into()), + ..Default::default() + }], + response_format: Some(ResponseFormat::default()), + ..Default::default() + }; + + let resp: ChatCompletionResponse = client + .post(format!("{API_BASE}/v1/chat/completions")) + .header("Authorization", &auth) + .json(&req) + .send() + .await? + .json() + .await?; + + assert!(!resp.choices.is_empty()); + + let json_result = match &resp.choices[0].message.content { + MessageContent::Text(text) => text, + _ => "非文本消息", + }; + + println!("\n====================================="); + println!("📄 JSON 格式输出测试结果"); + println!("====================================="); + println!("JSON 内容:{}", json_result); + println!("完整响应:{}", serde_json::to_string_pretty(&resp)?); + println!("✅ JSON输出测试通过"); + + Ok(()) + } + + // 5. 测试:工具调用 ✅ 无报错,保持原样 + #[tokio::test] + async fn test_chat_tool_call() -> anyhow::Result<()> { + let client = Client::new(); + let auth = format!("Bearer {}", get_api_key()); + + let req = CreateChatCompletionRequest { + model: "deepseek-v4-flash".into(), + messages: vec![ChatMessage { + role: Role::User, + content: MessageContent::Text("查询上海天气".into()), + ..Default::default() + }], + tools: Some(vec![Tool { + r#type: ToolType::Function, + function: FunctionObject { + name: "get_weather".into(), + description: Some("获取城市天气".into()), + parameters: Some(serde_json::to_value( + serde_json::json!({ + "type": "object", + "properties": { + "city": {"type": "string"} + }, + "required": ["city"] + }) + )?), + }, + }]), + tool_choice: Some(ToolChoice::Strategy("auto".into())), + ..Default::default() + }; + + let resp: ChatCompletionResponse = client + .post(format!("{API_BASE}/v1/chat/completions")) + .header("Authorization", &auth) + .json(&req) + .send() + .await? + .json() + .await?; + + assert!(resp.choices[0].message.tool_calls.is_some()); + + let tool_call = &resp.choices[0].message.tool_calls.as_ref().unwrap()[0]; + println!("\n====================================="); + println!("🔧 工具调用测试结果"); + println!("====================================="); + println!("调用函数:{}", tool_call.function.name); + println!("调用参数:{}", tool_call.function.arguments); + println!("完整响应:{}", serde_json::to_string_pretty(&resp)?); + println!("✅ 工具调用测试通过"); + + Ok(()) + } + + // 6. 测试:流式输出 ✅ 无报错,保持原样 + #[tokio::test] + async fn test_chat_stream() -> anyhow::Result<()> { + let client = Client::new(); + let auth = format!("Bearer {}", get_api_key()); + + let req = CreateChatCompletionRequest { + model: "deepseek-v4-flash".into(), + messages: vec![ChatMessage { + role: Role::User, + content: MessageContent::Text("介绍Rust".into()), + ..Default::default() + }], + stream: Some(true), + stream_options: Some(StreamOptions { + include_usage: Some(true), + }), + ..Default::default() + }; + + let mut stream = client + .post(format!("{API_BASE}/v1/chat/completions")) + .header("Authorization", &auth) + .json(&req) + .send() + .await? + .bytes_stream(); + + println!("\n====================================="); + println!("🌊 流式输出测试结果(实时打印)"); + println!("====================================="); + + let mut full_content = String::new(); + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result?; + let s = String::from_utf8_lossy(&chunk); + print!("{}", s); + full_content.push_str(&s); + } + + assert!(!full_content.is_empty()); + println!("\n✅ 流式输出测试通过"); + + Ok(()) + } +} \ No newline at end of file diff --git a/src/core/chat/provider/deepseek/tool.rs b/src/core/chat/provider/deepseek/tool.rs new file mode 100644 index 0000000..2bc3c1a --- /dev/null +++ b/src/core/chat/provider/deepseek/tool.rs @@ -0,0 +1,49 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use super::enums::ToolType; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Tool { + pub r#type: ToolType, + pub function: FunctionObject, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionObject { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub r#type: ToolType, + pub function: ToolCallFunction, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolCallFunction { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Strategy(String), + Tool(ToolChoiceObject), +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolChoiceObject { + pub r#type: ToolType, + pub function: ToolChoiceFunction, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolChoiceFunction { + pub name: String, +} \ No newline at end of file diff --git a/src/core/chat/provider/mod.rs b/src/core/chat/provider/mod.rs new file mode 100644 index 0000000..556dda3 --- /dev/null +++ b/src/core/chat/provider/mod.rs @@ -0,0 +1,4 @@ +mod anthropic; +mod common; +mod openai; +mod deepseek; \ No newline at end of file diff --git a/src/core/chat/provider/openai/mod.rs b/src/core/chat/provider/openai/mod.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/core/mod.rs b/src/core/mod.rs index f5562e2..6ec46a2 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -2,4 +2,5 @@ pub mod action; pub mod event; pub mod update; pub mod context; -pub mod db; \ No newline at end of file +pub mod db; +pub mod chat; \ No newline at end of file