diff --git a/charts/openab/templates/configmap.yaml b/charts/openab/templates/configmap.yaml index 12576a8e..32106c2b 100644 --- a/charts/openab/templates/configmap.yaml +++ b/charts/openab/templates/configmap.yaml @@ -167,6 +167,9 @@ data: api_key = "${STT_API_KEY}" model = {{ ($cfg.stt).model | default "whisper-large-v3-turbo" | toJson }} base_url = {{ ($cfg.stt).baseUrl | default "https://api.groq.com/openai/v1" | toJson }} + {{- if hasKey ($cfg.stt | default dict) "echoTranscript" }} + echo_transcript = {{ ($cfg.stt).echoTranscript }} + {{- end }} {{- end }} {{- if ($cfg.gateway).enabled }} {{- if not ($cfg.gateway).url }} diff --git a/charts/openab/values.yaml b/charts/openab/values.yaml index 4314d2df..81a5e706 100644 --- a/charts/openab/values.yaml +++ b/charts/openab/values.yaml @@ -222,6 +222,9 @@ agents: apiKey: "" model: "whisper-large-v3-turbo" baseUrl: "https://api.groq.com/openai/v1" + # Echo the transcribed text back to the thread before the agent reply + # so users can verify STT accuracy. Default: false (opt-in). + echoTranscript: false gateway: enabled: false # set to true + provide url to enable the [gateway] config block deploy: true # set to false to skip Gateway Deployment/Service (config-only mode) diff --git a/docs/config-reference.md b/docs/config-reference.md index 622dd7d3..9ddaf40f 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -204,6 +204,7 @@ Speech-to-text transcription for voice messages. Uses an OpenAI-compatible `/aud | `api_key` | string | `""` | API key for the STT service. When empty and `base_url` contains `groq.com`, the `GROQ_API_KEY` environment variable is used automatically. For local servers, use `api_key = "not-needed"`. | | `model` | string | `"whisper-large-v3-turbo"` | Model name to use for transcription. | | `base_url` | string | `"https://api.groq.com/openai/v1"` | Base URL of the STT API. Any OpenAI-compatible `/audio/transcriptions` endpoint works. | +| `echo_transcript` | bool | `false` | When set to `true` and STT runs, post a `> 🎤 ` message to the thread before the agent reply so users can verify what was heard. Failures show `(transcription failed)` and add a ⚠️ reaction to the original message. | --- diff --git a/docs/stt.md b/docs/stt.md index 5ee8fa48..202f9678 100644 --- a/docs/stt.md +++ b/docs/stt.md @@ -50,6 +50,7 @@ enabled = true # default: false api_key = "${GROQ_API_KEY}" # required for cloud providers model = "whisper-large-v3-turbo" # default base_url = "https://api.groq.com/openai/v1" # default +echo_transcript = true # default: false (opt-in) ``` | Field | Required | Default | Description | @@ -58,6 +59,7 @@ base_url = "https://api.groq.com/openai/v1" # default | `api_key` | no* | — | API key for the STT provider. *Auto-detected from `GROQ_API_KEY` env var if not set. For local servers, use any non-empty string (e.g. `"not-needed"`). | | `model` | no | `whisper-large-v3-turbo` | Whisper model name. Varies by provider. | | `base_url` | no | `https://api.groq.com/openai/v1` | OpenAI-compatible API base URL. | +| `echo_transcript` | no | `false` | When set to `true` and STT runs, post a `> 🎤 ` message to the thread before the agent reply so users can verify what was heard. Failures show `(transcription failed)` and add a ⚠️ reaction to the original message. | ## Deployment Options @@ -147,6 +149,13 @@ helm upgrade openab openab/openab \ --set agents.kiro.stt.baseUrl=https://api.groq.com/openai/v1 ``` +```bash +helm upgrade openab openab/openab \ + --set agents.kiro.stt.enabled=true \ + --set agents.kiro.stt.apiKey=gsk_xxx \ + --set agents.kiro.stt.echoTranscript=true # opt in to transcript echo +``` + ## Disabling STT Omit the `[stt]` section entirely, or set: diff --git a/src/acp/connection.rs b/src/acp/connection.rs index f49c0f50..c1b36a47 100644 --- a/src/acp/connection.rs +++ b/src/acp/connection.rs @@ -1,4 +1,6 @@ -use crate::acp::protocol::{ConfigOption, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, parse_config_options}; +use crate::acp::protocol::{ + parse_config_options, ConfigOption, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, +}; use anyhow::{anyhow, Result}; use serde_json::{json, Value}; use std::collections::HashMap; @@ -10,7 +12,6 @@ use tokio::sync::{mpsc, oneshot, Mutex}; use tokio::task::JoinHandle; use tracing::{debug, error, info}; - /// Pick the most permissive selectable permission option from ACP options. fn pick_best_option(options: &[Value]) -> Option { let mut fallback: Option<&Value> = None; @@ -187,20 +188,39 @@ impl AcpConnection { // Preserve the real HOME so agents can find OAuth/auth files (~/.codex, // ~/.claude, ~/.config/gh, etc.). working_dir is already set via // current_dir() above and is not necessarily the user's home directory. - cmd.env("HOME", std::env::var("HOME").unwrap_or_else(|_| working_dir.into())); - cmd.env("PATH", std::env::var("PATH").unwrap_or_else(|_| "/usr/local/bin:/usr/bin:/bin".into())); + cmd.env( + "HOME", + std::env::var("HOME").unwrap_or_else(|_| working_dir.into()), + ); + cmd.env( + "PATH", + std::env::var("PATH").unwrap_or_else(|_| "/usr/local/bin:/usr/bin:/bin".into()), + ); #[cfg(unix)] { - cmd.env("USER", std::env::var("USER").unwrap_or_else(|_| "agent".into())); + cmd.env( + "USER", + std::env::var("USER").unwrap_or_else(|_| "agent".into()), + ); } #[cfg(windows)] { // Windows requires SystemRoot for DLL loading and basic OS functionality. // USERPROFILE is the Windows equivalent of HOME. - cmd.env("USERPROFILE", std::env::var("USERPROFILE").unwrap_or_else(|_| working_dir.into())); - cmd.env("USERNAME", std::env::var("USERNAME").unwrap_or_else(|_| "agent".into())); - if let Ok(v) = std::env::var("SystemRoot") { cmd.env("SystemRoot", v); } - if let Ok(v) = std::env::var("SystemDrive") { cmd.env("SystemDrive", v); } + cmd.env( + "USERPROFILE", + std::env::var("USERPROFILE").unwrap_or_else(|_| working_dir.into()), + ); + cmd.env( + "USERNAME", + std::env::var("USERNAME").unwrap_or_else(|_| "agent".into()), + ); + if let Ok(v) = std::env::var("SystemRoot") { + cmd.env("SystemRoot", v); + } + if let Ok(v) = std::env::var("SystemDrive") { + cmd.env("SystemDrive", v); + } } for (k, v) in env { cmd.env(k, expand_env(v)); @@ -223,8 +243,7 @@ impl AcpConnection { let mut proc = cmd .spawn() .map_err(|e| anyhow!("failed to spawn {command}: {e}"))?; - let child_pgid = proc.id() - .and_then(|pid| i32::try_from(pid).ok()); + let child_pgid = proc.id().and_then(|pid| i32::try_from(pid).ok()); let stdout = proc.stdout.take().ok_or_else(|| anyhow!("no stdout"))?; let stdin = proc.stdin.take().ok_or_else(|| anyhow!("no stdin"))?; @@ -403,19 +422,22 @@ impl AcpConnection { .and_then(|c| c.get("loadSession")) .and_then(|v| v.as_bool()) .unwrap_or(false); - info!(agent = agent_name, load_session = self.supports_load_session, "initialized"); + info!( + agent = agent_name, + load_session = self.supports_load_session, + "initialized" + ); Ok(()) } pub async fn session_new(&mut self, cwd: &str) -> Result { let resp = self - .send_request( - "session/new", - Some(json!({"cwd": cwd, "mcpServers": []})), - ) + .send_request("session/new", Some(json!({"cwd": cwd, "mcpServers": []}))) .await?; - let session_id = resp.result.as_ref() + let session_id = resp + .result + .as_ref() .and_then(|r| r.get("sessionId")) .and_then(|s| s.as_str()) .ok_or_else(|| anyhow!("no sessionId in session/new response"))? @@ -434,7 +456,11 @@ impl AcpConnection { /// Set a config option (e.g. model, mode) via ACP session/set_config_option. /// Returns the updated list of all config options. - pub async fn set_config_option(&mut self, config_id: &str, value: &str) -> Result> { + pub async fn set_config_option( + &mut self, + config_id: &str, + value: &str, + ) -> Result> { let session_id = self .acp_session_id .as_ref() @@ -462,7 +488,10 @@ impl AcpConnection { Err(_) => { // Fall back: send as a slash command (e.g. "/model claude-sonnet-4") let cmd = format!("/{config_id} {value}"); - info!(cmd, "set_config_option not supported, falling back to prompt"); + info!( + cmd, + "set_config_option not supported, falling back to prompt" + ); let _resp = self .send_request( "session/prompt", @@ -503,10 +532,7 @@ impl AcpConnection { let id = self.next_id(); // Convert content blocks to JSON - let prompt_json: Vec = content_blocks - .iter() - .map(|b| b.to_json()) - .collect(); + let prompt_json: Vec = content_blocks.iter().map(|b| b.to_json()).collect(); let req = JsonRpcRequest::new( id, @@ -572,11 +598,15 @@ impl AcpConnection { #[cfg(unix)] { // Stage 1: SIGTERM the process group - unsafe { libc::kill(-pgid, libc::SIGTERM); } + unsafe { + libc::kill(-pgid, libc::SIGTERM); + } // Stage 2: SIGKILL after brief grace (std::thread survives runtime shutdown) std::thread::spawn(move || { std::thread::sleep(std::time::Duration::from_millis(1500)); - unsafe { libc::kill(-pgid, libc::SIGKILL); } + unsafe { + libc::kill(-pgid, libc::SIGKILL); + } }); } #[cfg(not(unix))] diff --git a/src/acp/mod.rs b/src/acp/mod.rs index c67cad82..f7d0141e 100644 --- a/src/acp/mod.rs +++ b/src/acp/mod.rs @@ -2,6 +2,6 @@ pub mod connection; pub mod pool; pub mod protocol; +pub use connection::ContentBlock; pub use pool::SessionPool; pub use protocol::{classify_notification, AcpEvent}; -pub use connection::ContentBlock; diff --git a/src/acp/pool.rs b/src/acp/pool.rs index a146abb0..6ccd3631 100644 --- a/src/acp/pool.rs +++ b/src/acp/pool.rs @@ -32,12 +32,7 @@ pub struct SessionPool { mapping_path: PathBuf, } -type EvictionCandidate = ( - String, - Arc>, - Instant, - Option, -); +type EvictionCandidate = (String, Arc>, Instant, Option); fn remove_if_same_handle( map: &mut HashMap>>, @@ -54,10 +49,7 @@ fn remove_if_same_handle( } } -fn get_or_insert_gate( - map: &mut HashMap>>, - key: &str, -) -> Arc> { +fn get_or_insert_gate(map: &mut HashMap>>, key: &str) -> Arc> { map.entry(key.to_string()) .or_insert_with(|| Arc::new(Mutex::new(()))) .clone() @@ -104,7 +96,9 @@ impl SessionPool { } }; let tmp = self.mapping_path.with_extension("json.tmp"); - if let Err(e) = std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.mapping_path)) { + if let Err(e) = + std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.mapping_path)) + { warn!(path = %self.mapping_path.display(), error = %e, "failed to persist thread mapping"); } } @@ -157,7 +151,12 @@ impl SessionPool { skipped_locked_candidates += 1; continue; }; - let candidate = (key, conn_handle, conn.last_active, conn.acp_session_id.clone()); + let candidate = ( + key, + conn_handle, + conn.last_active, + conn.acp_session_id.clone(), + ); match &eviction_candidate { Some((_, _, oldest_last_active, _)) if candidate.2 >= *oldest_last_active => {} _ => eviction_candidate = Some(candidate), @@ -250,7 +249,9 @@ impl SessionPool { state.active.insert(thread_id.to_string(), new_conn); self.save_mapping(&state.suspended); if !cancel_session_id.is_empty() { - state.cancel_handles.insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); + state + .cancel_handles + .insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); } Ok(()) } @@ -260,7 +261,9 @@ impl SessionPool { where F: for<'a> FnOnce( &'a mut AcpConnection, - ) -> std::pin::Pin> + Send + 'a>>, + ) -> std::pin::Pin< + Box> + Send + 'a>, + >, { let conn = { let state = self.state.read().await; @@ -311,7 +314,10 @@ impl SessionPool { pub async fn cancel_session(&self, thread_id: &str) -> Result<()> { let (stdin, session_id) = { let state = self.state.read().await; - state.cancel_handles.get(thread_id).cloned() + state + .cancel_handles + .get(thread_id) + .cloned() .ok_or_else(|| anyhow!("no session for thread {thread_id}"))? }; let data = serde_json::to_string(&serde_json::json!({ @@ -414,7 +420,11 @@ impl SessionPool { // awaiting a connection lock). let snapshot: Vec<(String, Arc>)> = { let state = self.state.read().await; - state.active.iter().map(|(k, v)| (k.clone(), Arc::clone(v))).collect() + state + .active + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() }; let mut session_ids: Vec<(String, String)> = Vec::new(); diff --git a/src/acp/protocol.rs b/src/acp/protocol.rs index 25cfb937..40dfdf07 100644 --- a/src/acp/protocol.rs +++ b/src/acp/protocol.rs @@ -14,7 +14,12 @@ pub struct JsonRpcRequest { impl JsonRpcRequest { pub fn new(id: u64, method: impl Into, params: Option) -> Self { - Self { jsonrpc: "2.0", id, method: method.into(), params } + Self { + jsonrpc: "2.0", + id, + method: method.into(), + params, + } } } @@ -27,7 +32,11 @@ pub struct JsonRpcResponse { impl JsonRpcResponse { pub fn new(id: u64, result: Value) -> Self { - Self { jsonrpc: "2.0", id, result } + Self { + jsonrpc: "2.0", + id, + result, + } } } @@ -95,17 +104,26 @@ pub fn parse_config_options(result: &Value) -> Vec { let mut options = Vec::new(); if let Some(models) = result.get("models") { - let current = models.get("currentModelId").and_then(|v| v.as_str()).unwrap_or(""); + let current = models + .get("currentModelId") + .and_then(|v| v.as_str()) + .unwrap_or(""); if let Some(available) = models.get("availableModels").and_then(|v| v.as_array()) { let values: Vec = available .iter() .filter_map(|m| { - let id = m.get("modelId").or_else(|| m.get("id")).and_then(|v| v.as_str())?; + let id = m + .get("modelId") + .or_else(|| m.get("id")) + .and_then(|v| v.as_str())?; let name = m.get("name").and_then(|v| v.as_str()).unwrap_or(id); Some(ConfigOptionValue { value: id.to_string(), name: name.to_string(), - description: m.get("description").and_then(|v| v.as_str()).map(String::from), + description: m + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), }) }) .collect(); @@ -124,7 +142,10 @@ pub fn parse_config_options(result: &Value) -> Vec { } if let Some(modes) = result.get("modes") { - let current = modes.get("currentModeId").and_then(|v| v.as_str()).unwrap_or(""); + let current = modes + .get("currentModeId") + .and_then(|v| v.as_str()) + .unwrap_or(""); if let Some(available) = modes.get("availableModes").and_then(|v| v.as_array()) { let values: Vec = available .iter() @@ -134,7 +155,10 @@ pub fn parse_config_options(result: &Value) -> Vec { Some(ConfigOptionValue { value: id.to_string(), name: name.to_string(), - description: m.get("description").and_then(|v| v.as_str()).map(String::from), + description: m + .get("description") + .and_then(|v| v.as_str()) + .map(String::from), }) }) .collect(); @@ -161,9 +185,18 @@ pub fn parse_config_options(result: &Value) -> Vec { pub enum AcpEvent { Text(String), Thinking, - ToolStart { id: String, title: String }, - ToolDone { id: String, title: String, status: String }, - ConfigUpdate { options: Vec }, + ToolStart { + id: String, + title: String, + }, + ToolDone { + id: String, + title: String, + status: String, + }, + ConfigUpdate { + options: Vec, + }, Status, } @@ -190,18 +223,32 @@ pub fn classify_notification(msg: &JsonRpcMessage) -> Option { let text = update.get("content")?.get("text")?.as_str()?; Some(AcpEvent::Text(text.to_string())) } - "agent_thought_chunk" => { - Some(AcpEvent::Thinking) - } + "agent_thought_chunk" => Some(AcpEvent::Thinking), "tool_call" => { - let title = update.get("title").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let title = update + .get("title") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); Some(AcpEvent::ToolStart { id: tool_id, title }) } "tool_call_update" => { - let title = update.get("title").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let status = update.get("status").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let title = update + .get("title") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let status = update + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); if status == "completed" || status == "failed" { - Some(AcpEvent::ToolDone { id: tool_id, title, status }) + Some(AcpEvent::ToolDone { + id: tool_id, + title, + status, + }) } else { Some(AcpEvent::ToolStart { id: tool_id, title }) } diff --git a/src/adapter.rs b/src/adapter.rs index 89b2ae4d..106cd47b 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -306,7 +306,15 @@ impl AdapterRouter { reactions: Arc, other_bot_present: bool, ) -> Result<()> { - self.stream_prompt_blocks(adapter, thread_key, content_blocks, thread_channel, reactions, other_bot_present).await + self.stream_prompt_blocks( + adapter, + thread_key, + content_blocks, + thread_channel, + reactions, + other_bot_present, + ) + .await } /// Drive one ACP turn with the given pre-packed ContentBlocks. diff --git a/src/bot_turns.rs b/src/bot_turns.rs index 9f031d14..92ff7e72 100644 --- a/src/bot_turns.rs +++ b/src/bot_turns.rs @@ -34,7 +34,10 @@ pub struct BotTurnTracker { impl BotTurnTracker { pub fn new(soft_limit: u32) -> Self { - Self { soft_limit, counts: HashMap::new() } + Self { + soft_limit, + counts: HashMap::new(), + } } pub fn on_bot_message(&mut self, thread_id: &str) -> TurnResult { @@ -307,12 +310,18 @@ mod tests { assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); assert!(matches!( t.classify_bot_message("t1"), - TurnAction::WarnAndStop { severity: TurnSeverity::Soft, .. }, + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + .. + }, )); assert_eq!(t.classify_bot_message("t2"), TurnAction::Continue); assert!(matches!( t.classify_bot_message("t2"), - TurnAction::WarnAndStop { severity: TurnSeverity::Soft, .. }, + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + .. + }, )); } @@ -333,7 +342,11 @@ mod tests { assert_eq!(t.classify_bot_message("t1"), TurnAction::Continue); assert!(matches!( t.classify_bot_message("t1"), - TurnAction::WarnAndStop { severity: TurnSeverity::Soft, turns: 2, .. }, + TurnAction::WarnAndStop { + severity: TurnSeverity::Soft, + turns: 2, + .. + }, )); } } diff --git a/src/config.rs b/src/config.rs index 574b2660..dd56484d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -57,7 +57,10 @@ impl<'de> Deserialize<'de> for AllowBots { "off" | "none" | "false" => Ok(Self::Off), "mentions" => Ok(Self::Mentions), "all" | "true" => Ok(Self::All), - other => Err(serde::de::Error::unknown_variant(other, &["off", "mentions", "all"])), + other => Err(serde::de::Error::unknown_variant( + other, + &["off", "mentions", "all"], + )), } } } @@ -102,6 +105,10 @@ pub struct SttConfig { pub model: String, #[serde(default = "default_stt_base_url")] pub base_url: String, + /// Echo the transcribed text back to the thread (no mentions) before + /// dispatching the prompt to the agent. Lets users verify STT accuracy. + #[serde(default = "default_echo_transcript")] + pub echo_transcript: bool, } impl Default for SttConfig { @@ -111,12 +118,20 @@ impl Default for SttConfig { api_key: String::new(), model: default_stt_model(), base_url: default_stt_base_url(), + echo_transcript: default_echo_transcript(), } } } -fn default_stt_model() -> String { "whisper-large-v3-turbo".into() } -fn default_stt_base_url() -> String { "https://api.groq.com/openai/v1".into() } +fn default_stt_model() -> String { + "whisper-large-v3-turbo".into() +} +fn default_stt_base_url() -> String { + "https://api.groq.com/openai/v1".into() +} +fn default_echo_transcript() -> bool { + false +} #[derive(Debug, Deserialize)] pub struct DiscordConfig { @@ -161,9 +176,15 @@ pub struct DiscordConfig { pub max_batch_tokens: usize, } -fn default_max_bot_turns() -> u32 { 100 } -fn default_max_buffered_messages() -> usize { 10 } -fn default_max_batch_tokens() -> usize { 24_000 } +fn default_max_bot_turns() -> u32 { + 100 +} +fn default_max_buffered_messages() -> usize { + 10 +} +fn default_max_batch_tokens() -> usize { + 24_000 +} /// Controls whether the bot responds to user messages in threads without @mention. /// @@ -188,7 +209,10 @@ impl<'de> Deserialize<'de> for AllowUsers { "involved" => Ok(Self::Involved), "mentions" => Ok(Self::Mentions), "multibot_mentions" => Ok(Self::MultibotMentions), - other => Err(serde::de::Error::unknown_variant(other, &["involved", "mentions", "multibot-mentions"])), + other => Err(serde::de::Error::unknown_variant( + other, + &["involved", "mentions", "multibot-mentions"], + )), } } } @@ -315,9 +339,15 @@ pub struct CronJobConfig { pub timezone: String, } -fn default_cron_platform() -> String { "discord".into() } -fn default_cron_sender() -> String { "openab-cron".into() } -fn default_cron_timezone() -> String { "UTC".into() } +fn default_cron_platform() -> String { + "discord".into() +} +fn default_cron_sender() -> String { + "openab-cron".into() +} +fn default_cron_timezone() -> String { + "UTC".into() +} /// Controls how tool calls are rendered in chat messages. /// @@ -339,7 +369,10 @@ impl<'de> Deserialize<'de> for ToolDisplay { "full" => Ok(Self::Full), "compact" => Ok(Self::Compact), "none" | "off" | "hidden" => Ok(Self::None), - other => Err(serde::de::Error::unknown_variant(other, &["full", "compact", "none"])), + other => Err(serde::de::Error::unknown_variant( + other, + &["full", "compact", "none"], + )), } } } @@ -392,28 +425,63 @@ pub struct ReactionTiming { // --- defaults --- -fn default_working_dir() -> String { "/tmp".into() } -fn default_max_sessions() -> usize { 10 } -fn default_ttl_hours() -> u64 { 4 } -fn default_true() -> bool { true } - -fn emoji_queued() -> String { "👀".into() } -fn emoji_thinking() -> String { "🤔".into() } -fn emoji_tool() -> String { "🔥".into() } -fn emoji_coding() -> String { "👨‍💻".into() } -fn emoji_web() -> String { "⚡".into() } -fn emoji_done() -> String { "🆗".into() } -fn emoji_error() -> String { "😱".into() } - -fn default_debounce_ms() -> u64 { 700 } -fn default_stall_soft_ms() -> u64 { 10_000 } -fn default_stall_hard_ms() -> u64 { 30_000 } -fn default_done_hold_ms() -> u64 { 1_500 } -fn default_error_hold_ms() -> u64 { 2_500 } +fn default_working_dir() -> String { + "/tmp".into() +} +fn default_max_sessions() -> usize { + 10 +} +fn default_ttl_hours() -> u64 { + 4 +} +fn default_true() -> bool { + true +} + +fn emoji_queued() -> String { + "👀".into() +} +fn emoji_thinking() -> String { + "🤔".into() +} +fn emoji_tool() -> String { + "🔥".into() +} +fn emoji_coding() -> String { + "👨‍💻".into() +} +fn emoji_web() -> String { + "⚡".into() +} +fn emoji_done() -> String { + "🆗".into() +} +fn emoji_error() -> String { + "😱".into() +} + +fn default_debounce_ms() -> u64 { + 700 +} +fn default_stall_soft_ms() -> u64 { + 10_000 +} +fn default_stall_hard_ms() -> u64 { + 30_000 +} +fn default_done_hold_ms() -> u64 { + 1_500 +} +fn default_error_hold_ms() -> u64 { + 2_500 +} impl Default for PoolConfig { fn default() -> Self { - Self { max_sessions: default_max_sessions(), session_ttl_hours: default_ttl_hours() } + Self { + max_sessions: default_max_sessions(), + session_ttl_hours: default_ttl_hours(), + } } } @@ -432,8 +500,13 @@ impl Default for ReactionsConfig { impl Default for ReactionEmojis { fn default() -> Self { Self { - queued: emoji_queued(), thinking: emoji_thinking(), tool: emoji_tool(), - coding: emoji_coding(), web: emoji_web(), done: emoji_done(), error: emoji_error(), + queued: emoji_queued(), + thinking: emoji_thinking(), + tool: emoji_tool(), + coding: emoji_coding(), + web: emoji_web(), + done: emoji_done(), + error: emoji_error(), } } } @@ -441,8 +514,10 @@ impl Default for ReactionEmojis { impl Default for ReactionTiming { fn default() -> Self { Self { - debounce_ms: default_debounce_ms(), stall_soft_ms: default_stall_soft_ms(), - stall_hard_ms: default_stall_hard_ms(), done_hold_ms: default_done_hold_ms(), + debounce_ms: default_debounce_ms(), + stall_soft_ms: default_stall_soft_ms(), + stall_hard_ms: default_stall_hard_ms(), + done_hold_ms: default_done_hold_ms(), error_hold_ms: default_error_hold_ms(), } } @@ -516,16 +591,31 @@ fn parse_config(raw: &str, source: &str) -> anyhow::Result { // and max_batch_tokens > 0 (otherwise the consumer's token-cap check forces every // batch to size 1 — functionally per-message via a confusing path). if let Some(ref d) = config.discord { - anyhow::ensure!(d.max_buffered_messages > 0, "discord.max_buffered_messages must be > 0"); - anyhow::ensure!(d.max_batch_tokens > 0, "discord.max_batch_tokens must be > 0"); + anyhow::ensure!( + d.max_buffered_messages > 0, + "discord.max_buffered_messages must be > 0" + ); + anyhow::ensure!( + d.max_batch_tokens > 0, + "discord.max_batch_tokens must be > 0" + ); } if let Some(ref s) = config.slack { - anyhow::ensure!(s.max_buffered_messages > 0, "slack.max_buffered_messages must be > 0"); + anyhow::ensure!( + s.max_buffered_messages > 0, + "slack.max_buffered_messages must be > 0" + ); anyhow::ensure!(s.max_batch_tokens > 0, "slack.max_batch_tokens must be > 0"); } if let Some(ref g) = config.gateway { - anyhow::ensure!(g.max_buffered_messages > 0, "gateway.max_buffered_messages must be > 0"); - anyhow::ensure!(g.max_batch_tokens > 0, "gateway.max_batch_tokens must be > 0"); + anyhow::ensure!( + g.max_buffered_messages > 0, + "gateway.max_buffered_messages must be > 0" + ); + anyhow::ensure!( + g.max_batch_tokens > 0, + "gateway.max_batch_tokens must be > 0" + ); } Ok(config) @@ -586,7 +676,10 @@ command = "echo" fn parse_invalid_toml_returns_error() { let result = parse_config("not valid toml {{{}}", "test"); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("failed to parse config from test")); + assert!(result + .unwrap_err() + .to_string() + .contains("failed to parse config from test")); } #[test] @@ -608,7 +701,10 @@ command = "echo" async fn load_config_from_url_invalid_host() { let result = load_config_from_url("https://invalid.test.example/config.toml").await; assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("failed to fetch remote config")); + assert!(result + .unwrap_err() + .to_string() + .contains("failed to fetch remote config")); } #[test] @@ -630,7 +726,10 @@ command = "echo" assert!(gw.allow_all_channels.is_none()); // resolve_allow_all: empty lists → allow all assert!(resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); - assert!(resolve_allow_all(gw.allow_all_channels, &gw.allowed_channels)); + assert!(resolve_allow_all( + gw.allow_all_channels, + &gw.allowed_channels + )); } #[test] @@ -652,7 +751,10 @@ command = "echo" assert_eq!(gw.allowed_channels, vec!["C1"]); // resolve_allow_all: non-empty lists → restricted assert!(!resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); - assert!(!resolve_allow_all(gw.allow_all_channels, &gw.allowed_channels)); + assert!(!resolve_allow_all( + gw.allow_all_channels, + &gw.allowed_channels + )); } #[test] @@ -764,4 +866,29 @@ command = "echo" // explicit flag overrides non-empty list assert!(resolve_allow_all(gw.allow_all_users, &gw.allowed_users)); } + + #[test] + fn stt_echo_transcript_defaults_to_false() { + let cfg = SttConfig::default(); + assert!( + !cfg.echo_transcript, + "echo_transcript should default to false" + ); + } + + #[test] + fn stt_echo_transcript_respects_explicit_false() { + let toml = r#" +[agent] +command = "echo" + +[stt] +enabled = true +api_key = "test" +echo_transcript = false +"#; + let cfg = parse_config(toml, "test").unwrap(); + assert!(cfg.stt.enabled); + assert!(!cfg.stt.echo_transcript); + } } diff --git a/src/cron.rs b/src/cron.rs index 1aec1621..9f39f9b5 100644 --- a/src/cron.rs +++ b/src/cron.rs @@ -1,4 +1,4 @@ -use crate::adapter::{AdapterRouter, ChatAdapter, ChannelRef, SenderContext}; +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, SenderContext}; use crate::config::CronJobConfig; use crate::format; use chrono::{Timelike, Utc}; @@ -24,9 +24,7 @@ pub fn parse_cron_expr(expr: &str) -> Result { /// schedule has an event at exactly that minute. pub fn should_fire(schedule: &Schedule, tz: Tz) -> bool { let now = Utc::now().with_timezone(&tz); - let minute_start = now - .with_second(0).unwrap() - .with_nanosecond(0).unwrap(); + let minute_start = now.with_second(0).unwrap().with_nanosecond(0).unwrap(); let query_from = minute_start - chrono::Duration::seconds(1); schedule .after(&query_from) @@ -39,20 +37,35 @@ pub fn should_fire(schedule: &Schedule, tz: Tz) -> bool { const VALID_PLATFORMS: &[&str] = &["discord", "slack"]; /// Validate all cronjob configs (fail-fast on bad cron expressions or timezones). -pub fn validate_cronjobs(cronjobs: &[CronJobConfig], configured_platforms: &[&str]) -> anyhow::Result<()> { +pub fn validate_cronjobs( + cronjobs: &[CronJobConfig], + configured_platforms: &[&str], +) -> anyhow::Result<()> { for (i, job) in cronjobs.iter().enumerate() { - if !job.enabled { continue; } + if !job.enabled { + continue; + } parse_cron_expr(&job.schedule).map_err(|e| { - anyhow::anyhow!("cronjobs[{i}]: invalid cron expression {:?}: {e}", job.schedule) + anyhow::anyhow!( + "cronjobs[{i}]: invalid cron expression {:?}: {e}", + job.schedule + ) })?; job.timezone.parse::().map_err(|e| { anyhow::anyhow!("cronjobs[{i}]: invalid timezone {:?}: {e}", job.timezone) })?; if !VALID_PLATFORMS.contains(&job.platform.as_str()) { - anyhow::bail!("cronjobs[{i}]: unknown platform {:?} (expected one of: {VALID_PLATFORMS:?})", job.platform); + anyhow::bail!( + "cronjobs[{i}]: unknown platform {:?} (expected one of: {VALID_PLATFORMS:?})", + job.platform + ); } if !configured_platforms.contains(&job.platform.as_str()) { - anyhow::bail!("cronjobs[{i}]: platform {:?} is not configured — add [{}] to config.toml", job.platform, job.platform); + anyhow::bail!( + "cronjobs[{i}]: platform {:?} is not configured — add [{}] to config.toml", + job.platform, + job.platform + ); } } Ok(()) @@ -183,7 +196,9 @@ pub async fn run_scheduler( if baseline_jobs.is_empty() && usercron_jobs.is_empty() { if usercron_path.is_some() { - info!("no cronjobs yet, but usercron_path is set — scheduler will watch for cronjob.toml"); + info!( + "no cronjobs yet, but usercron_path is set — scheduler will watch for cronjob.toml" + ); } else { debug!("no cronjobs configured, scheduler not started"); return; @@ -191,14 +206,23 @@ pub async fn run_scheduler( } let total = baseline_jobs.len() + usercron_jobs.len(); - info!(baseline = baseline_jobs.len(), usercron = usercron_jobs.len(), total, "cron scheduler started"); + info!( + baseline = baseline_jobs.len(), + usercron = usercron_jobs.len(), + total, + "cron scheduler started" + ); let in_flight: Arc>> = Arc::new(Mutex::new(HashSet::new())); // Align to next minute boundary let now = Utc::now(); let secs_into_minute = now.timestamp() % 60; - let align_delay = if secs_into_minute == 0 { 0 } else { 60 - secs_into_minute as u64 }; + let align_delay = if secs_into_minute == 0 { + 0 + } else { + 60 - secs_into_minute as u64 + }; if align_delay > 0 { debug!(align_secs = align_delay, "aligning to next minute boundary"); tokio::time::sleep(std::time::Duration::from_secs(align_delay)).await; @@ -301,7 +325,10 @@ async fn fire_cronjob( adapters: &HashMap>, in_flight: Arc>>, ) { - let _guard = InFlightGuard { idx, set: in_flight }; + let _guard = InFlightGuard { + idx, + set: in_flight, + }; let adapter = match adapters.get(&job.platform) { Some(a) => a.clone(), @@ -319,7 +346,13 @@ async fn fire_cronjob( origin_event_id: None, }; - let trigger_msg = match adapter.send_message(&thread_channel, &format!("🕐 [{}]: {}", job.sender_name, job.message)).await { + let trigger_msg = match adapter + .send_message( + &thread_channel, + &format!("🕐 [{}]: {}", job.sender_name, job.message), + ) + .await + { Ok(msg) => msg, Err(e) => { error!(channel = %job.channel, error = %e, "failed to send cron message"); @@ -331,11 +364,19 @@ async fn fire_cronjob( thread_channel.clone() } else { let thread_name = format::shorten_thread_name(&job.message); - match adapter.create_thread(&thread_channel, &trigger_msg, &thread_name).await { + match adapter + .create_thread(&thread_channel, &trigger_msg, &thread_name) + .await + { Ok(ch) => ch, Err(e) => { error!(channel = %job.channel, error = %e, "failed to create cron thread"); - let _ = adapter.send_message(&thread_channel, &format!("⚠️ cronjob: failed to create thread: {e}")).await; + let _ = adapter + .send_message( + &thread_channel, + &format!("⚠️ cronjob: failed to create thread: {e}"), + ) + .await; return; } } @@ -347,8 +388,15 @@ async fn fire_cronjob( sender_name: job.sender_name.clone(), display_name: job.sender_name.clone(), channel: job.platform.clone(), - channel_id: reply_channel.parent_id.as_deref().unwrap_or(&reply_channel.channel_id).to_string(), - thread_id: reply_channel.thread_id.clone().or(Some(reply_channel.channel_id.clone())), + channel_id: reply_channel + .parent_id + .as_deref() + .unwrap_or(&reply_channel.channel_id) + .to_string(), + thread_id: reply_channel + .thread_id + .clone() + .or(Some(reply_channel.channel_id.clone())), is_bot: true, timestamp: Some(Utc::now().to_rfc3339()), }; @@ -361,18 +409,23 @@ async fn fire_cronjob( }; if let Err(e) = router - .handle_message(&adapter, crate::adapter::MessageContext { - thread_channel: reply_channel.clone(), - sender_json, - prompt: job.message.clone(), - extra_blocks: vec![], - trigger_msg, - other_bot_present: false, - }) + .handle_message( + &adapter, + crate::adapter::MessageContext { + thread_channel: reply_channel.clone(), + sender_json, + prompt: job.message.clone(), + extra_blocks: vec![], + trigger_msg, + other_bot_present: false, + }, + ) .await { error!("cron handle_message error: {e}"); - let _ = adapter.send_message(&reply_channel, &format!("⚠️ cronjob error: {e}")).await; + let _ = adapter + .send_message(&reply_channel, &format!("⚠️ cronjob error: {e}")) + .await; } } @@ -502,12 +555,16 @@ thread_id = "789" fn load_usercron_valid_file() { let dir = tempfile::tempdir().unwrap(); let path = dir.path().join("cronjob.toml"); - std::fs::write(&path, r#" + std::fs::write( + &path, + r#" [[jobs]] schedule = "* * * * *" channel = "123" message = "ping" -"#).unwrap(); +"#, + ) + .unwrap(); let jobs = load_usercron_file(&path, &["discord"]); assert_eq!(jobs.len(), 1); assert_eq!(jobs[0].message, "ping"); @@ -526,7 +583,9 @@ message = "ping" fn load_usercron_skips_invalid_entries() { let dir = tempfile::tempdir().unwrap(); let path = dir.path().join("cronjob.toml"); - std::fs::write(&path, r#" + std::fs::write( + &path, + r#" [[jobs]] schedule = "* * * * *" channel = "123" @@ -536,7 +595,9 @@ message = "good" schedule = "bad cron" channel = "456" message = "bad" -"#).unwrap(); +"#, + ) + .unwrap(); let jobs = load_usercron_file(&path, &["discord"]); assert_eq!(jobs.len(), 1); assert_eq!(jobs[0].message, "good"); @@ -546,7 +607,9 @@ message = "bad" fn load_usercron_skips_unconfigured_platform() { let dir = tempfile::tempdir().unwrap(); let path = dir.path().join("cronjob.toml"); - std::fs::write(&path, r#" + std::fs::write( + &path, + r#" [[jobs]] schedule = "* * * * *" channel = "123" @@ -557,7 +620,9 @@ schedule = "* * * * *" channel = "456" message = "slack job" platform = "slack" -"#).unwrap(); +"#, + ) + .unwrap(); // Only discord configured let jobs = load_usercron_file(&path, &["discord"]); assert_eq!(jobs.len(), 1); @@ -569,9 +634,14 @@ platform = "slack" #[test] fn validate_cronjobs_valid_passes() { let jobs = vec![CronJobConfig { - enabled: true, schedule: "0 9 * * 1-5".into(), channel: "123".into(), - message: "hi".into(), platform: "discord".into(), sender_name: "test".into(), - thread_id: None, timezone: "UTC".into(), + enabled: true, + schedule: "0 9 * * 1-5".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), }]; assert!(validate_cronjobs(&jobs, &["discord"]).is_ok()); } @@ -579,9 +649,14 @@ platform = "slack" #[test] fn validate_cronjobs_invalid_cron_fails() { let jobs = vec![CronJobConfig { - enabled: true, schedule: "bad".into(), channel: "123".into(), - message: "hi".into(), platform: "discord".into(), sender_name: "test".into(), - thread_id: None, timezone: "UTC".into(), + enabled: true, + schedule: "bad".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), }]; let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); assert!(err.to_string().contains("invalid cron expression")); @@ -590,9 +665,14 @@ platform = "slack" #[test] fn validate_cronjobs_invalid_timezone_fails() { let jobs = vec![CronJobConfig { - enabled: true, schedule: "* * * * *".into(), channel: "123".into(), - message: "hi".into(), platform: "discord".into(), sender_name: "test".into(), - thread_id: None, timezone: "Mars/Olympus".into(), + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "Mars/Olympus".into(), }]; let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); assert!(err.to_string().contains("invalid timezone")); @@ -601,9 +681,14 @@ platform = "slack" #[test] fn validate_cronjobs_unknown_platform_fails() { let jobs = vec![CronJobConfig { - enabled: true, schedule: "* * * * *".into(), channel: "123".into(), - message: "hi".into(), platform: "telegram".into(), sender_name: "test".into(), - thread_id: None, timezone: "UTC".into(), + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "telegram".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), }]; let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); assert!(err.to_string().contains("unknown platform")); @@ -612,9 +697,14 @@ platform = "slack" #[test] fn validate_cronjobs_unconfigured_platform_fails() { let jobs = vec![CronJobConfig { - enabled: true, schedule: "* * * * *".into(), channel: "123".into(), - message: "hi".into(), platform: "slack".into(), sender_name: "test".into(), - thread_id: None, timezone: "UTC".into(), + enabled: true, + schedule: "* * * * *".into(), + channel: "123".into(), + message: "hi".into(), + platform: "slack".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), }]; let err = validate_cronjobs(&jobs, &["discord"]).unwrap_err(); assert!(err.to_string().contains("not configured")); @@ -623,9 +713,14 @@ platform = "slack" #[test] fn validate_cronjobs_disabled_with_invalid_cron_passes() { let jobs = vec![CronJobConfig { - enabled: false, schedule: "bad".into(), channel: "123".into(), - message: "hi".into(), platform: "discord".into(), sender_name: "test".into(), - thread_id: None, timezone: "UTC".into(), + enabled: false, + schedule: "bad".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), }]; assert!(validate_cronjobs(&jobs, &["discord"]).is_ok()); } @@ -633,9 +728,14 @@ platform = "slack" #[test] fn validate_cronjobs_enabled_with_invalid_cron_still_fails() { let jobs = vec![CronJobConfig { - enabled: true, schedule: "bad".into(), channel: "123".into(), - message: "hi".into(), platform: "discord".into(), sender_name: "test".into(), - thread_id: None, timezone: "UTC".into(), + enabled: true, + schedule: "bad".into(), + channel: "123".into(), + message: "hi".into(), + platform: "discord".into(), + sender_name: "test".into(), + thread_id: None, + timezone: "UTC".into(), }]; assert!(validate_cronjobs(&jobs, &["discord"]).is_err()); } diff --git a/src/discord.rs b/src/discord.rs index 13987dea..a8b27be2 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -1,23 +1,27 @@ -use crate::acp::ContentBlock; use crate::acp::protocol::ConfigOption; -use crate::adapter::{AdapterRouter, ChatAdapter, ChannelRef, MessageRef, SenderContext}; +use crate::acp::ContentBlock; +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef, SenderContext}; use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; use crate::config::{AllowBots, AllowUsers, SttConfig}; use crate::format; use crate::media; use async_trait::async_trait; -use std::sync::LazyLock; -use serenity::builder::{CreateActionRow, CreateButton, CreateCommand, CreateInteractionResponse, CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind, CreateSelectMenuOption, CreateThread, EditMessage}; -use serenity::model::application::ButtonStyle; +use serenity::builder::{ + CreateActionRow, CreateButton, CreateCommand, CreateInteractionResponse, + CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind, + CreateSelectMenuOption, CreateThread, EditMessage, +}; use serenity::http::Http; +use serenity::model::application::ButtonStyle; use serenity::model::application::{Command, ComponentInteractionDataKind, Interaction}; use serenity::model::channel::{AutoArchiveDuration, Message, MessageType, ReactionType}; use serenity::model::gateway::Ready; use serenity::model::id::{ChannelId, MessageId, UserId}; use serenity::prelude::*; use std::collections::{HashMap, HashSet}; +use std::sync::LazyLock; use std::sync::{Arc, OnceLock}; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; /// Hard cap on consecutive bot messages in a channel or thread. /// Prevents runaway loops between multiple bots in "all" mode. @@ -57,7 +61,11 @@ impl ChatAdapter for DiscordAdapter { 2000 } - async fn send_message(&self, channel: &ChannelRef, content: &str) -> anyhow::Result { + async fn send_message( + &self, + channel: &ChannelRef, + content: &str, + ) -> anyhow::Result { let ch_id: u64 = Self::resolve_channel(channel).parse()?; let msg = ChannelId::new(ch_id).say(&self.http, content).await?; Ok(MessageRef { @@ -181,11 +189,15 @@ impl Handler { // Check positive caches let cached_involved = { let cache = self.participated_threads.lock().await; - cache.get(&key).is_some_and(|ts| ts.elapsed() < self.session_ttl) + cache + .get(&key) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) }; let cached_multibot = { let cache = self.multibot_threads.lock().await; - cache.get(&key).is_some_and(|ts| ts.elapsed() < self.session_ttl) + cache + .get(&key) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) }; // Both cached → skip fetch entirely @@ -212,7 +224,10 @@ impl Handler { }; let involved = cached_involved || messages.iter().any(|m| m.author.id == bot_id); - let other_bot_present = cached_multibot || messages.iter().any(|m| m.author.bot && m.author.id != bot_id); + let other_bot_present = cached_multibot + || messages + .iter() + .any(|m| m.author.bot && m.author.id != bot_id); if involved && !cached_involved { let mut cache = self.participated_threads.lock().await; @@ -277,7 +292,11 @@ impl EventHandler for Handler { match tracker.classify_bot_message(&thread_key) { TurnAction::Continue => {} TurnAction::SilentStop => return, - TurnAction::WarnAndStop { severity, turns, user_message } => { + TurnAction::WarnAndStop { + severity, + turns, + user_message, + } => { match severity { TurnSeverity::Hard => tracing::warn!( channel_id = %msg.channel_id, @@ -350,28 +369,36 @@ impl EventHandler for Handler { return; } - let adapter = self.adapter.get_or_init(|| { - Arc::new(DiscordAdapter::new(ctx.http.clone())) - }).clone(); + let adapter = self + .adapter + .get_or_init(|| Arc::new(DiscordAdapter::new(ctx.http.clone()))) + .clone(); let channel_id = msg.channel_id.get(); let in_allowed_channel = self.allow_all_channels || self.allowed_channels.contains(&channel_id); - let is_mentioned = msg.mentions_user_id(bot_id) - || msg.content.contains(&format!("<@{}>", bot_id)); + let is_mentioned = + msg.mentions_user_id(bot_id) || msg.content.contains(&format!("<@{}>", bot_id)); // Bot message gating (from upstream #321) if msg.author.bot { match self.allow_bot_messages { AllowBots::Off => return, - AllowBots::Mentions => if !is_mentioned { return; }, + AllowBots::Mentions => { + if !is_mentioned { + return; + } + } AllowBots::All => { let cap = MAX_CONSECUTIVE_BOT_TURNS as usize; let limit = std::cmp::min(MAX_CONSECUTIVE_BOT_TURNS, 100) as u8; - let history = ctx.cache.channel_messages(msg.channel_id) + let history = ctx + .cache + .channel_messages(msg.channel_id) .map(|msgs| { - let mut recent: Vec<_> = msgs.iter() + let mut recent: Vec<_> = msgs + .iter() .filter(|(mid, _)| **mid < msg.id) .map(|(_, m)| m.clone()) .collect(); @@ -384,8 +411,14 @@ impl EventHandler for Handler { let recent = if let Some(cached) = history { cached } else { - match msg.channel_id - .messages(&ctx.http, serenity::builder::GetMessages::new().before(msg.id).limit(limit)) + match msg + .channel_id + .messages( + &ctx.http, + serenity::builder::GetMessages::new() + .before(msg.id) + .limit(limit), + ) .await { Ok(msgs) => msgs, @@ -396,17 +429,20 @@ impl EventHandler for Handler { } }; - let consecutive_bot = recent.iter() + let consecutive_bot = recent + .iter() .take_while(|m| m.author.bot && m.author.id != bot_id) .count(); if consecutive_bot >= cap { tracing::warn!(channel_id = %msg.channel_id, cap, "bot turn cap reached, ignoring"); return; } - }, + } } - if !self.trusted_bot_ids.is_empty() && !self.trusted_bot_ids.contains(&msg.author.id.get()) { + if !self.trusted_bot_ids.is_empty() + && !self.trusted_bot_ids.contains(&msg.author.id.get()) + { tracing::debug!(bot_id = %msg.author.id, "bot not in trusted_bot_ids, ignoring"); return; } @@ -415,7 +451,11 @@ impl EventHandler for Handler { // Thread detection: single to_channel() call for both allowed and // non-allowed channels. Uses thread_metadata (not parent_id) to // identify threads — see detect_thread() doc comments for rationale. - let (in_thread, bot_owns_thread, thread_parent_id, is_dm) = match msg.channel_id.to_channel(&ctx.http).await { + let (in_thread, bot_owns_thread, thread_parent_id, is_dm) = match msg + .channel_id + .to_channel(&ctx.http) + .await + { Ok(serenity::model::channel::Channel::Guild(gc)) => { let parent = gc.parent_id.map(|id| id.get().to_string()); let result = detect_thread( @@ -436,7 +476,12 @@ impl EventHandler for Handler { bot_owns = ?result.1, "thread check" ); - (result.0, result.1.unwrap_or(false), if result.0 { parent } else { None }, false) + ( + result.0, + result.1.unwrap_or(false), + if result.0 { parent } else { None }, + false, + ) } Ok(serenity::model::channel::Channel::Private(_)) => { tracing::debug!(channel_id = %msg.channel_id, "DM channel"); @@ -513,7 +558,12 @@ impl EventHandler for Handler { } } - if is_denied_user(msg.author.bot, self.allow_all_users, &self.allowed_users, msg.author.id.get()) { + if is_denied_user( + msg.author.bot, + self.allow_all_users, + &self.allowed_users, + msg.author.id.get(), + ) { tracing::info!(user_id = %msg.author.id, "denied user, ignoring"); let msg_ref = discord_msg_ref(&msg); let _ = adapter.add_reaction(&msg_ref, "🚫").await; @@ -544,6 +594,7 @@ impl EventHandler for Handler { // Build extra content blocks from attachments (audio → STT, text → inline, image → encode) let mut extra_blocks = Vec::new(); + let mut echo_entries: Vec = Vec::new(); let mut text_file_bytes: u64 = 0; let mut text_file_count: u32 = 0; const TEXT_TOTAL_CAP: u64 = 1024 * 1024; // 1 MB total for all text file attachments @@ -554,25 +605,38 @@ impl EventHandler for Handler { if media::is_audio_mime(mime) { if self.stt_config.enabled { let mime_clean = mime.split(';').next().unwrap_or(mime).trim(); - if let Some(transcript) = media::download_and_transcribe( + match media::download_and_transcribe( &attachment.url, &attachment.filename, mime_clean, u64::from(attachment.size), &self.stt_config, None, - ).await { - debug!(filename = %attachment.filename, chars = transcript.len(), "voice transcript injected"); - extra_blocks.insert(0, ContentBlock::Text { - text: format!("[Voice message transcript]: {transcript}"), - }); + ) + .await + { + Some(transcript) => { + debug!(filename = %attachment.filename, chars = transcript.len(), "voice transcript injected"); + extra_blocks.insert( + 0, + ContentBlock::Text { + text: format!("[Voice message transcript]: {transcript}"), + }, + ); + echo_entries.push(crate::stt::EchoEntry::Success(transcript)); + } + None => { + warn!(filename = %attachment.filename, "STT failed for voice attachment"); + echo_entries.push(crate::stt::EchoEntry::Failed); + } } } else { tracing::warn!(filename = %attachment.filename, "skipping audio attachment (STT disabled)"); let msg_ref = discord_msg_ref(&msg); let _ = adapter.add_reaction(&msg_ref, "🎤").await; } - } else if media::is_text_file(&attachment.filename, attachment.content_type.as_deref()) { + } else if media::is_text_file(&attachment.filename, attachment.content_type.as_deref()) + { if text_file_count >= TEXT_FILE_COUNT_CAP { tracing::warn!(filename = %attachment.filename, count = text_file_count, "text file count cap reached, skipping"); continue; @@ -588,7 +652,9 @@ impl EventHandler for Handler { &attachment.filename, u64::from(attachment.size), None, - ).await { + ) + .await + { text_file_bytes += actual_bytes; text_file_count += 1; debug!(filename = %attachment.filename, "adding text file attachment"); @@ -600,7 +666,9 @@ impl EventHandler for Handler { &attachment.filename, u64::from(attachment.size), None, - ).await { + ) + .await + { debug!(url = %attachment.url, filename = %attachment.filename, "adding image attachment"); extra_blocks.push(block); } @@ -649,15 +717,24 @@ impl EventHandler for Handler { } let dispatcher = self.dispatcher.clone(); + let stt_cfg = self.stt_config.clone(); tokio::spawn(async move { + // Best-effort echo before the agent reply so the user can verify STT. + crate::stt::post_echo( + &adapter, + &thread_channel, + &trigger_msg, + &echo_entries, + &stt_cfg, + ) + .await; + let sender_id = sender.sender_id.clone(); let sender_name = sender.sender_name.clone(); let sender_json = serde_json::to_string(&sender).unwrap(); - let thread_key = - dispatcher.key("discord", &thread_channel.channel_id, &sender_id); - let estimated_tokens = - crate::dispatch::estimate_tokens(&prompt, &extra_blocks); + let thread_key = dispatcher.key("discord", &thread_channel.channel_id, &sender_id); + let estimated_tokens = crate::dispatch::estimate_tokens(&prompt, &extra_blocks); let buf_msg = crate::dispatch::BufferedMessage { sender_json, sender_name, @@ -682,16 +759,12 @@ impl EventHandler for Handler { // Build the shared command list once. let commands = vec![ - CreateCommand::new("models") - .description("Select the AI model for this session"), - CreateCommand::new("agents") - .description("Select the agent mode for this session"), - CreateCommand::new("cancel") - .description("Cancel the current operation"), + CreateCommand::new("models").description("Select the AI model for this session"), + CreateCommand::new("agents").description("Select the agent mode for this session"), + CreateCommand::new("cancel").description("Cancel the current operation"), CreateCommand::new("cancel-all") .description("Cancel current operation and drop all buffered messages"), - CreateCommand::new("reset") - .description("Reset the conversation session"), + CreateCommand::new("reset").description("Reset the conversation session"), ]; // Register global commands (works in DMs + all guilds after propagation). @@ -704,10 +777,7 @@ impl EventHandler for Handler { // Also register per-guild for instant availability (global can take up to 1h). for guild in &ready.guilds { let guild_id = guild.id; - if let Err(e) = guild_id - .set_commands(&ctx.http, commands.clone()) - .await - { + if let Err(e) = guild_id.set_commands(&ctx.http, commands.clone()).await { tracing::warn!(%guild_id, error = %e, "failed to register guild slash commands"); } else { info!(%guild_id, "registered guild slash commands"); @@ -718,10 +788,12 @@ impl EventHandler for Handler { async fn interaction_create(&self, ctx: Context, interaction: Interaction) { match interaction { Interaction::Command(cmd) if cmd.data.name == "models" => { - self.handle_config_command(&ctx, &cmd, "model", "model").await; + self.handle_config_command(&ctx, &cmd, "model", "model") + .await; } Interaction::Command(cmd) if cmd.data.name == "agents" => { - self.handle_config_command(&ctx, &cmd, "agent", "agent").await; + self.handle_config_command(&ctx, &cmd, "agent", "agent") + .await; } Interaction::Command(cmd) if cmd.data.name == "cancel" => { self.handle_cancel_command(&ctx, &cmd).await; @@ -743,19 +815,26 @@ impl EventHandler for Handler { } } - // --- Slash command & interaction handlers --- impl Handler { /// Build a Discord select menu from ACP configOptions with the given category. /// Paginates options in pages of 25 (Discord limit). The current selection is /// always placed first so it appears on page 0. - fn build_config_select(options: &[ConfigOption], category: &str, page: usize) -> Option { - let opt = options.iter().find(|o| o.category.as_deref() == Some(category))?; + fn build_config_select( + options: &[ConfigOption], + category: &str, + page: usize, + ) -> Option { + let opt = options + .iter() + .find(|o| o.category.as_deref() == Some(category))?; // Put current selection first so it always lands on page 0, // then fill remaining slots in original order. - let sorted: Vec<_> = opt.options.iter() + let sorted: Vec<_> = opt + .options + .iter() .filter(|o| o.value == opt.current_value) .chain(opt.options.iter().filter(|o| o.value != opt.current_value)) .collect(); @@ -780,13 +859,20 @@ impl Handler { return None; } - let current_name = opt.options.iter() + let current_name = opt + .options + .iter() .find(|o| o.value == opt.current_value) .map(|o| o.name.as_str()) .unwrap_or(&opt.current_value); let total_pages = sorted.len().div_ceil(SELECT_MENU_PAGE_SIZE); let placeholder = if total_pages > 1 { - format!("Current: {} (page {}/{})", current_name, page + 1, total_pages) + format!( + "Current: {} (page {}/{})", + current_name, + page + 1, + total_pages + ) } else { format!("Current: {}", current_name) }; @@ -794,14 +880,20 @@ impl Handler { Some( CreateSelectMenu::new( format!("acp_config_{}", opt.id), - CreateSelectMenuKind::String { options: menu_options }, + CreateSelectMenuKind::String { + options: menu_options, + }, ) - .placeholder(placeholder) + .placeholder(placeholder), ) } /// Build ◀/▶ pagination buttons. Returns None when only one page exists. - fn build_pagination_buttons(category: &str, page: usize, total_pages: usize) -> Option { + fn build_pagination_buttons( + category: &str, + page: usize, + total_pages: usize, + ) -> Option { if total_pages <= 1 { return None; } @@ -822,12 +914,20 @@ impl Handler { /// Build the full component rows (select menu + optional pagination) for a config category. /// When `page` is `None`, auto-selects the page containing the current value. - fn build_config_components(options: &[ConfigOption], category: &str, page: Option) -> Option> { - let opt = options.iter().find(|o| o.category.as_deref() == Some(category))?; + fn build_config_components( + options: &[ConfigOption], + category: &str, + page: Option, + ) -> Option> { + let opt = options + .iter() + .find(|o| o.category.as_deref() == Some(category))?; let total_pages = opt.options.len().div_ceil(SELECT_MENU_PAGE_SIZE); let page = match page { Some(p) => p.min(total_pages.saturating_sub(1)), - None => opt.options.iter() + None => opt + .options + .iter() .position(|o| o.value == opt.current_value) .map(|i| i / SELECT_MENU_PAGE_SIZE) .unwrap_or(0), @@ -884,7 +984,9 @@ impl Handler { }; let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new().content(msg).ephemeral(true), + CreateInteractionResponseMessage::new() + .content(msg) + .ephemeral(true), ); if let Err(e) = cmd.create_response(&ctx.http, response).await { tracing::error!(error = %e, "failed to respond to /cancel command"); @@ -910,12 +1012,16 @@ impl Handler { let msg = match (cancel_result, dropped) { (Ok(()), 0) => "🛑 Cancel signal sent.".to_string(), (Ok(()), _) => "🛑 Cancel signal sent. Buffered messages cleared.".to_string(), - (Err(_), 0) => "⚠️ Nothing to cancel — no active session and no buffered messages.".to_string(), + (Err(_), 0) => { + "⚠️ Nothing to cancel — no active session and no buffered messages.".to_string() + } (Err(_), _) => "🛑 Buffered messages cleared. No active session to cancel.".to_string(), }; let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new().content(msg).ephemeral(true), + CreateInteractionResponseMessage::new() + .content(msg) + .ephemeral(true), ); if let Err(e) = cmd.create_response(&ctx.http, response).await { tracing::error!(error = %e, "failed to respond to /cancel-all command"); @@ -944,11 +1050,16 @@ impl Handler { Err(_) if dropped > 0 => { format!("🔄 Dropped {dropped} buffered message(s). No active session to reset.") } - Err(_) => "⚠️ No active session to reset. Start a conversation first by @mentioning the bot.".to_string(), + Err(_) => { + "⚠️ No active session to reset. Start a conversation first by @mentioning the bot." + .to_string() + } }; let response = CreateInteractionResponse::Message( - CreateInteractionResponseMessage::new().content(msg).ephemeral(true), + CreateInteractionResponseMessage::new() + .content(msg) + .ephemeral(true), ); if let Err(e) = cmd.create_response(&ctx.http, response).await { tracing::error!(error = %e, "failed to respond to /reset command"); @@ -972,12 +1083,10 @@ impl Handler { } let selected_value = match &comp.data.kind { - ComponentInteractionDataKind::StringSelect { values } => { - match values.first() { - Some(v) => v.clone(), - None => return, - } - } + ComponentInteractionDataKind::StringSelect { values } => match values.first() { + Some(v) => v.clone(), + None => return, + }, _ => return, }; @@ -1006,7 +1115,9 @@ impl Handler { }; let response = CreateInteractionResponse::UpdateMessage( - CreateInteractionResponseMessage::new().content(response_msg).components(vec![]), + CreateInteractionResponseMessage::new() + .content(response_msg) + .components(vec![]), ); if let Err(e) = comp.create_response(&ctx.http, response).await { @@ -1100,7 +1211,10 @@ async fn get_or_create_thread( origin_event_id: None, }; let trigger_ref = discord_msg_ref(msg); - match adapter.create_thread(&parent, &trigger_ref, &thread_name).await { + match adapter + .create_thread(&parent, &trigger_ref, &thread_name) + .await + { Ok(ch) => Ok(ch), Err(e) if is_thread_already_exists_error(&e) => { // Another bot won the race from the same trigger message. Discord @@ -1110,9 +1224,9 @@ async fn get_or_create_thread( .channel_id .message(&ctx.http, msg.id) .await - .map_err(|fe| anyhow::anyhow!( - "thread_already_exists (race), but refetch failed: {fe}" - ))?; + .map_err(|fe| { + anyhow::anyhow!("thread_already_exists (race), but refetch failed: {fe}") + })?; let existing = refreshed.thread.ok_or_else(|| { anyhow::anyhow!( "thread_already_exists (race), but message has no thread after refetch" @@ -1147,9 +1261,8 @@ fn is_thread_already_exists_error(err: &anyhow::Error) -> bool { msg.contains("160004") || msg.contains("already been created") } -static ROLE_MENTION_RE: LazyLock = LazyLock::new(|| { - regex::Regex::new(r"<@&\d+>").unwrap() -}); +static ROLE_MENTION_RE: LazyLock = + LazyLock::new(|| regex::Regex::new(r"<@&\d+>").unwrap()); fn resolve_mentions(content: &str, bot_id: UserId) -> String { // 1. Strip the bot's own trigger mention @@ -1236,7 +1349,12 @@ fn detect_thread( /// Returns `true` if the author should be denied by the user allowlist. /// Bot authors skip this check — they are gated by `allow_bot_messages` + `trusted_bot_ids`. -fn is_denied_user(is_bot: bool, allow_all_users: bool, allowed_users: &HashSet, user_id: u64) -> bool { +fn is_denied_user( + is_bot: bool, + allow_all_users: bool, + allowed_users: &HashSet, + user_id: u64, +) -> bool { !is_bot && !allow_all_users && !allowed_users.contains(&user_id) } @@ -1290,7 +1408,7 @@ fn should_process_user_message( #[cfg(test)] mod tests { use super::*; - use crate::bot_turns::{HARD_BOT_TURN_LIMIT, TurnResult}; + use crate::bot_turns::{TurnResult, HARD_BOT_TURN_LIMIT}; // --- resolve_mentions tests --- @@ -1376,10 +1494,10 @@ mod tests { fn multibot_mentions_single_bot_thread_no_mention() { assert!(should_process_user_message( AllowUsers::MultibotMentions, - false, // is_mentioned - true, // in_thread - true, // involved - false, // other_bot_present + false, // is_mentioned + true, // in_thread + true, // involved + false, // other_bot_present )); } @@ -1391,10 +1509,10 @@ mod tests { fn multibot_mentions_multi_bot_thread_no_mention() { assert!(!should_process_user_message( AllowUsers::MultibotMentions, - false, // is_mentioned - true, // in_thread - true, // involved - true, // other_bot_present ← another bot posted + false, // is_mentioned + true, // in_thread + true, // involved + true, // other_bot_present ← another bot posted )); } @@ -1405,10 +1523,10 @@ mod tests { fn multibot_mentions_multi_bot_thread_with_mention() { assert!(should_process_user_message( AllowUsers::MultibotMentions, - true, // is_mentioned - true, // in_thread - true, // involved - true, // other_bot_present + true, // is_mentioned + true, // in_thread + true, // involved + true, // other_bot_present )); } @@ -1419,10 +1537,10 @@ mod tests { fn multibot_mentions_main_channel_no_mention() { assert!(!should_process_user_message( AllowUsers::MultibotMentions, - false, // is_mentioned - false, // in_thread (main channel) - false, // involved - false, // other_bot_present + false, // is_mentioned + false, // in_thread (main channel) + false, // involved + false, // other_bot_present )); } @@ -1433,10 +1551,10 @@ mod tests { fn multibot_mentions_not_involved() { assert!(!should_process_user_message( AllowUsers::MultibotMentions, - false, // is_mentioned - true, // in_thread - false, // involved ← bot hasn't posted here - false, // other_bot_present + false, // is_mentioned + true, // in_thread + false, // involved ← bot hasn't posted here + false, // other_bot_present )); } @@ -1447,10 +1565,10 @@ mod tests { fn involved_mode_ignores_multibot() { assert!(should_process_user_message( AllowUsers::Involved, - false, // is_mentioned - true, // in_thread - true, // involved - true, // other_bot_present ← ignored in involved mode + false, // is_mentioned + true, // in_thread + true, // involved + true, // other_bot_present ← ignored in involved mode )); } @@ -1461,10 +1579,10 @@ mod tests { fn mentions_mode_always_requires_mention() { assert!(!should_process_user_message( AllowUsers::Mentions, - false, // is_mentioned - true, // in_thread - true, // involved - false, // other_bot_present + false, // is_mentioned + true, // in_thread + true, // involved + false, // other_bot_present )); } @@ -1518,7 +1636,15 @@ mod tests { /// In-thread message: channel_id = parent, thread_id = thread channel ID. #[test] fn build_sender_context_in_thread() { - let ctx = build_sender_context("user1", "alice", "Alice", "thread_ch", Some("parent_ch"), false, "2026-05-01T00:00:00Z"); + let ctx = build_sender_context( + "user1", + "alice", + "Alice", + "thread_ch", + Some("parent_ch"), + false, + "2026-05-01T00:00:00Z", + ); assert_eq!(ctx.channel_id, "parent_ch"); assert_eq!(ctx.thread_id, Some("thread_ch".to_string())); assert_eq!(ctx.channel, "discord"); @@ -1529,7 +1655,15 @@ mod tests { /// Non-thread message: channel_id = message channel, thread_id = None. #[test] fn build_sender_context_not_in_thread() { - let ctx = build_sender_context("user1", "alice", "Alice", "main_ch", None, false, "2026-05-01T00:00:00Z"); + let ctx = build_sender_context( + "user1", + "alice", + "Alice", + "main_ch", + None, + false, + "2026-05-01T00:00:00Z", + ); assert_eq!(ctx.channel_id, "main_ch"); assert_eq!(ctx.thread_id, None); } @@ -1537,7 +1671,15 @@ mod tests { /// Bot sender: is_bot flag propagated correctly. #[test] fn build_sender_context_bot_sender() { - let ctx = build_sender_context("bot1", "mybot", "MyBot", "ch", Some("parent"), true, "2026-05-01T00:00:00Z"); + let ctx = build_sender_context( + "bot1", + "mybot", + "MyBot", + "ch", + Some("parent"), + true, + "2026-05-01T00:00:00Z", + ); assert!(ctx.is_bot); assert_eq!(ctx.channel_id, "parent"); assert_eq!(ctx.thread_id, Some("ch".to_string())); @@ -1704,8 +1846,12 @@ mod tests { let category_id: u64 = 200; let allowed = HashSet::from([category_id]); // Category child: has parent_id (the category) but NO thread_metadata. - let (in_thread, _) = detect_thread(false, Some(category_id), None, 1000, &allowed, false, false); - assert!(!in_thread, "category child must not match allowed_channels via parent_id"); + let (in_thread, _) = + detect_thread(false, Some(category_id), None, 1000, &allowed, false, false); + assert!( + !in_thread, + "category child must not match allowed_channels via parent_id" + ); } // --- Per-thread streaming tests (#534) --- @@ -1825,10 +1971,10 @@ mod tests { // because is_mentioned=false and in_thread=false. assert!(!should_process_user_message( AllowUsers::Involved, - false, // is_mentioned (DMs don't have @mention) - false, // in_thread (DMs are not threads) - false, // involved - false, // other_bot_present + false, // is_mentioned (DMs don't have @mention) + false, // in_thread (DMs are not threads) + false, // involved + false, // other_bot_present )); } diff --git a/src/dispatch.rs b/src/dispatch.rs index 1667521c..a3fbec88 100644 --- a/src/dispatch.rs +++ b/src/dispatch.rs @@ -17,8 +17,8 @@ use anyhow::Result; use async_trait::async_trait; use tracing::{debug, error, info, info_span, warn}; -use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef}; use crate::acp::ContentBlock; +use crate::adapter::{AdapterRouter, ChannelRef, ChatAdapter, MessageRef}; use crate::config::ReactionsConfig; use crate::error_display::format_user_error; use crate::reactions::StatusReactionController; @@ -196,9 +196,19 @@ pub fn dispatch_params( ) -> (usize, BatchGrouping, Duration) { use crate::config::MessageProcessingMode; match mode { - MessageProcessingMode::Message => (1, BatchGrouping::Thread, PER_MESSAGE_CONSUMER_IDLE_TIMEOUT), - MessageProcessingMode::Thread => (max_buffered, BatchGrouping::Thread, DEFAULT_CONSUMER_IDLE_TIMEOUT), - MessageProcessingMode::Lane => (max_buffered, BatchGrouping::Lane, DEFAULT_CONSUMER_IDLE_TIMEOUT), + MessageProcessingMode::Message => { + (1, BatchGrouping::Thread, PER_MESSAGE_CONSUMER_IDLE_TIMEOUT) + } + MessageProcessingMode::Thread => ( + max_buffered, + BatchGrouping::Thread, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ), + MessageProcessingMode::Lane => ( + max_buffered, + BatchGrouping::Lane, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ), } } @@ -394,7 +404,10 @@ impl Dispatcher { let _ = adapter .send_message( &thread_channel, - &format!("⚠️ {}", format_user_error("dispatch consumer exited unexpectedly")), + &format!( + "⚠️ {}", + format_user_error("dispatch consumer exited unexpectedly") + ), ) .await; return Err(DispatchError::ConsumerDead); @@ -740,11 +753,8 @@ mod tests { #[test] fn pack_arrival_event_single() { - let blocks = AdapterRouter::pack_arrival_event( - r#"{"schema":"openab.sender.v1"}"#, - "hello", - vec![], - ); + let blocks = + AdapterRouter::pack_arrival_event(r#"{"schema":"openab.sender.v1"}"#, "hello", vec![]); // sender_context delimiter + prompt = 2 blocks assert_eq!(blocks.len(), 2); if let ContentBlock::Text { text } = &blocks[0] { @@ -765,14 +775,23 @@ mod tests { #[test] fn pack_arrival_event_with_extra_blocks() { let extra = vec![ - ContentBlock::Text { text: "[Voice transcript]: hi".into() }, - ContentBlock::Image { media_type: "image/png".into(), data: "abc".into() }, + ContentBlock::Text { + text: "[Voice transcript]: hi".into(), + }, + ContentBlock::Image { + media_type: "image/png".into(), + data: "abc".into(), + }, ]; let blocks = AdapterRouter::pack_arrival_event("{}", "prompt", extra); // delimiter + transcript + prompt + image = 4 blocks assert_eq!(blocks.len(), 4); - assert!(matches!(&blocks[0], ContentBlock::Text { text } if text.contains(""))); - assert!(matches!(&blocks[1], ContentBlock::Text { text } if text.contains("Voice transcript"))); + assert!( + matches!(&blocks[0], ContentBlock::Text { text } if text.contains("")) + ); + assert!( + matches!(&blocks[1], ContentBlock::Text { text } if text.contains("Voice transcript")) + ); assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "prompt")); assert!(matches!(&blocks[3], ContentBlock::Image { .. })); } @@ -781,8 +800,16 @@ mod tests { fn pack_arrival_event_batch_n2() { // Two arrival events concatenated → 2 (header + prompt) pairs = 4 blocks. let mut all: Vec = Vec::new(); - all.extend(AdapterRouter::pack_arrival_event(r#"{"ts":"T1"}"#, "msg1", vec![])); - all.extend(AdapterRouter::pack_arrival_event(r#"{"ts":"T2"}"#, "msg2", vec![])); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"ts":"T1"}"#, + "msg1", + vec![], + )); + all.extend(AdapterRouter::pack_arrival_event( + r#"{"ts":"T2"}"#, + "msg2", + vec![], + )); assert_eq!(all.len(), 4); if let ContentBlock::Text { text } = &all[0] { assert!(text.contains(r#""ts":"T1""#)); @@ -1095,7 +1122,7 @@ mod tests { insert_dummy_handle(&d, "discord:T1:userA"); insert_dummy_handle(&d, "discord:T1:userB"); insert_dummy_handle(&d, "discord:T2:userA"); // different thread - insert_dummy_handle(&d, "slack:T1:userA"); // different platform + insert_dummy_handle(&d, "slack:T1:userA"); // different platform d.cancel_buffered_thread("discord", "T1"); let map = d.per_thread.lock().unwrap(); assert!(!map.contains_key("discord:T1:userA")); @@ -1268,11 +1295,18 @@ mod tests { #[async_trait] impl ChatAdapter for MockChatAdapter { - fn platform(&self) -> &'static str { "mock" } - fn message_limit(&self) -> usize { 2000 } + fn platform(&self) -> &'static str { + "mock" + } + fn message_limit(&self) -> usize { + 2000 + } async fn send_message(&self, channel: &ChannelRef, _content: &str) -> Result { - Ok(MessageRef { channel: channel.clone(), message_id: "mock-msg".into() }) + Ok(MessageRef { + channel: channel.clone(), + message_id: "mock-msg".into(), + }) } async fn create_thread( @@ -1284,9 +1318,15 @@ mod tests { Ok(channel.clone()) } - async fn add_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { Ok(()) } - async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { Ok(()) } - fn use_streaming(&self, _other_bot_present: bool) -> bool { false } + async fn add_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { + Ok(()) + } + async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { + Ok(()) + } + fn use_streaming(&self, _other_bot_present: bool) -> bool { + false + } } fn make_channel(thread: &str) -> ChannelRef { @@ -1301,7 +1341,8 @@ mod tests { fn make_msg(prompt: &str, tokens: usize) -> BufferedMessage { BufferedMessage { - sender_json: r#"{"schema":"openab.sender.v1","sender_id":"u","sender_name":"u"}"#.into(), + sender_json: r#"{"schema":"openab.sender.v1","sender_id":"u","sender_name":"u"}"# + .into(), sender_name: "u".into(), prompt: prompt.into(), extra_blocks: vec![], @@ -1403,7 +1444,10 @@ mod tests { )); // Wait enough for the timeout branch + a tick for the task to finish. tokio::time::sleep(Duration::from_millis(150)).await; - assert!(consumer.is_finished(), "consumer should exit after idle timeout"); + assert!( + consumer.is_finished(), + "consumer should exit after idle timeout" + ); // No dispatches should have been recorded. assert!(mock.calls().is_empty()); drop(tx); @@ -1417,7 +1461,13 @@ mod tests { // whose consumer is still parked but whose rx has been dropped. let mock = Arc::new(MockDispatchTarget::new()); let target: Arc = mock.clone(); - let d = Dispatcher::with_idle_timeout(target, 10, 24_000, BatchGrouping::Thread, DEFAULT_CONSUMER_IDLE_TIMEOUT); + let d = Dispatcher::with_idle_timeout( + target, + 10, + 24_000, + BatchGrouping::Thread, + DEFAULT_CONSUMER_IDLE_TIMEOUT, + ); let adapter: Arc = Arc::new(MockChatAdapter); let key = "mock:T".to_string(); @@ -1444,7 +1494,11 @@ mod tests { tokio::time::sleep(Duration::from_millis(50)).await; let calls = mock.calls(); - assert_eq!(calls.len(), 1, "fresh consumer should have dispatched the retry"); + assert_eq!( + calls.len(), + 1, + "fresh consumer should have dispatched the retry" + ); // pack_arrival_event with no extra_blocks → delimiter + prompt = 2 blocks. assert_eq!(calls[0].block_count, 2); diff --git a/src/error_display.rs b/src/error_display.rs index 40f1479a..b4e3a850 100644 --- a/src/error_display.rs +++ b/src/error_display.rs @@ -16,18 +16,25 @@ pub fn format_user_error(message: &str) -> String { if let Some(start) = msg_lower.find("timeout waiting for ") { let rest = &message[start + "timeout waiting for ".len()..]; let method = rest.split_whitespace().next().unwrap_or("request"); - return format!("**Request Timeout**\nTimeout waiting for {}, please try again.", method); + return format!( + "**Request Timeout**\nTimeout waiting for {}, please try again.", + method + ); } - return "**Request Timeout**\nTimeout waiting for a response, please try again.".to_string(); + return "**Request Timeout**\nTimeout waiting for a response, please try again." + .to_string(); } if msg_lower.contains("connection closed") || msg_lower.contains("channel closed") { - return "**Connection Lost**\nThe connection to the agent was lost, please try again.".to_string(); + return "**Connection Lost**\nThe connection to the agent was lost, please try again." + .to_string(); } if msg_lower.contains("failed to spawn") || msg_lower.contains("no such file") { - return "**Agent Not Found**\nCould not start the agent — please check your configuration.".to_string(); + return "**Agent Not Found**\nCould not start the agent — please check your configuration." + .to_string(); } if msg_lower.contains("pool exhausted") { - return "**Service Busy**\nAll agent sessions are in use, please try again shortly.".to_string(); + return "**Service Busy**\nAll agent sessions are in use, please try again shortly." + .to_string(); } if msg_lower.contains("invalid api key") || msg_lower.contains("unauthorized") { return "**Unauthorized**\nPlease check your API key configuration.".to_string(); diff --git a/src/gateway.rs b/src/gateway.rs index 8aed6aab..d8fa967c 100644 --- a/src/gateway.rs +++ b/src/gateway.rs @@ -107,12 +107,16 @@ struct GatewayResponse { // --- GatewayAdapter: ChatAdapter over WebSocket --- type PendingRequests = Arc>>>; -type SharedWsTx = Arc, +type SharedWsTx = Arc< + Mutex< + futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + Message, + >, >, - Message, ->>>; +>; pub struct GatewayAdapter { ws_tx: SharedWsTx, @@ -263,10 +267,7 @@ async fn handle_config_command( Err(e) => Some(format!("❌ Failed to switch: {e}")), }; } else { - return Some(format!( - "⚠️ Invalid number. Use 1–{}.", - all_values.len() - )); + return Some(format!("⚠️ Invalid number. Use 1–{}.", all_values.len())); } } // Exact match on value or name @@ -548,8 +549,12 @@ pub async fn run_gateway_adapter( let (ws_tx, mut ws_rx) = ws_stream.split(); let ws_tx: SharedWsTx = Arc::new(Mutex::new(ws_tx)); let pending: PendingRequests = Arc::new(Mutex::new(HashMap::new())); - let adapter: Arc = - Arc::new(GatewayAdapter::new(ws_tx.clone(), pending.clone(), platform, streaming)); + let adapter: Arc = Arc::new(GatewayAdapter::new( + ws_tx.clone(), + pending.clone(), + platform, + streaming, + )); let slash_ws_tx = ws_tx.clone(); // for fire-and-forget slash command responses let mut tasks: tokio::task::JoinSet<()> = tokio::task::JoinSet::new(); @@ -792,4 +797,3 @@ pub async fn run_gateway_adapter( backoff_secs = (backoff_secs * 2).min(MAX_BACKOFF); } // outer reconnect loop } - diff --git a/src/main.rs b/src/main.rs index 04a0937f..3cfce2db 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,13 +7,13 @@ mod discord; mod dispatch; mod error_display; mod format; +mod gateway; mod markdown; mod media; mod reactions; mod setup; mod slack; mod stt; -mod gateway; mod timestamp; use adapter::AdapterRouter; @@ -85,7 +85,9 @@ async fn main() -> anyhow::Result<()> { ) .init(); - let cmd = Cli::parse().command.unwrap_or(Commands::Run { config: None }); + let cmd = Cli::parse() + .command + .unwrap_or(Commands::Run { config: None }); let config_arg = match cmd { Commands::Setup { output } => { @@ -117,7 +119,9 @@ async fn main() -> anyhow::Result<()> { ); if cfg.discord.is_none() && cfg.slack.is_none() && cfg.gateway.is_none() { - anyhow::bail!("no adapter configured — add [discord], [slack], and/or [gateway] to config.toml"); + anyhow::bail!( + "no adapter configured — add [discord], [slack], and/or [gateway] to config.toml" + ); } let pool = Arc::new(acp::SessionPool::new(cfg.agent, cfg.pool.max_sessions)); @@ -139,7 +143,11 @@ async fn main() -> anyhow::Result<()> { info!(model = %cfg.stt.model, base_url = %cfg.stt.base_url, "STT enabled"); } - let router = Arc::new(AdapterRouter::new(pool.clone(), cfg.reactions, cfg.markdown.tables)); + let router = Arc::new(AdapterRouter::new( + pool.clone(), + cfg.reactions, + cfg.markdown.tables, + )); // Shutdown signal for Slack adapter let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); @@ -166,25 +174,36 @@ async fn main() -> anyhow::Result<()> { }); // Pre-build shared adapters for cron scheduler (avoids duplicate Http clients / rate-limit buckets) - let shared_discord_adapter: Option> = cfg.discord.as_ref().map(|dc| { - let http = Arc::new(serenity::http::Http::new(&dc.bot_token)); - Arc::new(discord::DiscordAdapter::new(http)) as Arc - }); + let shared_discord_adapter: Option> = + cfg.discord.as_ref().map(|dc| { + let http = Arc::new(serenity::http::Http::new(&dc.bot_token)); + Arc::new(discord::DiscordAdapter::new(http)) as Arc + }); let session_ttl_dur = std::time::Duration::from_secs(ttl_secs); let shared_slack_adapter: Option> = cfg.slack.as_ref().map(|s| { - Arc::new(slack::SlackAdapter::new(s.bot_token.clone(), session_ttl_dur, s.allow_bot_messages)) + Arc::new(slack::SlackAdapter::new( + s.bot_token.clone(), + session_ttl_dur, + s.allow_bot_messages, + )) }); // Validate cronjob config at startup (fail-fast on bad cron expressions or timezones) let mut configured_platforms: Vec<&str> = Vec::new(); - if cfg.discord.is_some() { configured_platforms.push("discord"); } - if cfg.slack.is_some() { configured_platforms.push("slack"); } + if cfg.discord.is_some() { + configured_platforms.push("discord"); + } + if cfg.slack.is_some() { + configured_platforms.push("slack"); + } cron::validate_cronjobs(&cfg.cron.jobs, &configured_platforms)?; // Spawn Slack adapter (background task) let slack_handle = if let Some(slack_cfg) = cfg.slack { - let allow_all_channels = config::resolve_allow_all(slack_cfg.allow_all_channels, &slack_cfg.allowed_channels); - let allow_all_users = config::resolve_allow_all(slack_cfg.allow_all_users, &slack_cfg.allowed_users); + let allow_all_channels = + config::resolve_allow_all(slack_cfg.allow_all_channels, &slack_cfg.allowed_channels); + let allow_all_users = + config::resolve_allow_all(slack_cfg.allow_all_users, &slack_cfg.allowed_users); if !allow_all_channels && slack_cfg.allowed_channels.is_empty() { warn!("allow_all_channels=false with empty allowed_channels for Slack — bot will deny all channels"); } @@ -201,7 +220,9 @@ async fn main() -> anyhow::Result<()> { let stt = cfg.stt.clone(); let max_bot_turns = slack_cfg.max_bot_turns; let slack_shutdown_rx = shutdown_rx.clone(); - let adapter = shared_slack_adapter.clone().expect("shared_slack_adapter must exist when slack config is present"); + let adapter = shared_slack_adapter + .clone() + .expect("shared_slack_adapter must exist when slack config is present"); // Dispatcher is the sole serialization path for all modes. Message = cap 1 // (each message dispatches alone, FIFO). Thread / Lane = configured cap; // grouping decides whether senders share a buffer or get their own lane. @@ -264,15 +285,23 @@ async fn main() -> anyhow::Result<()> { platform: gw_cfg.platform, token: gw_cfg.token, bot_username: gw_cfg.bot_username, - allow_all_channels: config::resolve_allow_all(gw_cfg.allow_all_channels, &gw_cfg.allowed_channels), + allow_all_channels: config::resolve_allow_all( + gw_cfg.allow_all_channels, + &gw_cfg.allowed_channels, + ), allowed_channels: gw_cfg.allowed_channels, - allow_all_users: config::resolve_allow_all(gw_cfg.allow_all_users, &gw_cfg.allowed_users), + allow_all_users: config::resolve_allow_all( + gw_cfg.allow_all_users, + &gw_cfg.allowed_users, + ), allowed_users: gw_cfg.allowed_users, streaming: gw_cfg.streaming, }; let gw_router = router.clone(); Some(tokio::spawn(async move { - if let Err(e) = gateway::run_gateway_adapter(params, shutdown_rx, gw_dispatcher, gw_router).await { + if let Err(e) = + gateway::run_gateway_adapter(params, shutdown_rx, gw_dispatcher, gw_router).await + { error!("gateway adapter error: {e}"); } })) @@ -311,10 +340,19 @@ async fn main() -> anyhow::Result<()> { if let Some(ref a) = shared_slack_adapter { cron_adapters.insert("slack".into(), a.clone() as Arc); } - let cron_platforms: Vec = configured_platforms.iter().map(|s| s.to_string()).collect(); + let cron_platforms: Vec = + configured_platforms.iter().map(|s| s.to_string()).collect(); info!(baseline = cronjobs.len(), usercron = ?usercron_path, "starting cron scheduler"); Some(tokio::spawn(async move { - cron::run_scheduler(cronjobs, usercron_path, cron_platforms, cron_router, cron_adapters, shutdown_rx).await; + cron::run_scheduler( + cronjobs, + usercron_path, + cron_platforms, + cron_router, + cron_adapters, + shutdown_rx, + ) + .await; })) } else { None @@ -322,15 +360,20 @@ async fn main() -> anyhow::Result<()> { // Run Discord adapter (foreground, blocking) or wait for ctrl_c if let Some(discord_cfg) = cfg.discord { - let allow_all_channels = config::resolve_allow_all(discord_cfg.allow_all_channels, &discord_cfg.allowed_channels); - let allow_all_users = config::resolve_allow_all(discord_cfg.allow_all_users, &discord_cfg.allowed_users); + let allow_all_channels = config::resolve_allow_all( + discord_cfg.allow_all_channels, + &discord_cfg.allowed_channels, + ); + let allow_all_users = + config::resolve_allow_all(discord_cfg.allow_all_users, &discord_cfg.allowed_users); let allowed_channels = parse_id_set(&discord_cfg.allowed_channels, "discord.allowed_channels")?; if !allow_all_channels && allowed_channels.is_empty() { warn!("allow_all_channels=false with empty allowed_channels for Discord — bot will deny all channels"); } let allowed_users = parse_id_set(&discord_cfg.allowed_users, "discord.allowed_users")?; - let trusted_bot_ids = parse_id_set(&discord_cfg.trusted_bot_ids, "discord.trusted_bot_ids")?; + let trusted_bot_ids = + parse_id_set(&discord_cfg.trusted_bot_ids, "discord.trusted_bot_ids")?; info!( allow_all_channels, allow_all_users, @@ -371,7 +414,9 @@ async fn main() -> anyhow::Result<()> { multibot_threads: tokio::sync::Mutex::new(std::collections::HashMap::new()), session_ttl: std::time::Duration::from_secs(ttl_secs), max_bot_turns: discord_cfg.max_bot_turns, - bot_turns: tokio::sync::Mutex::new(bot_turns::BotTurnTracker::new(discord_cfg.max_bot_turns)), + bot_turns: tokio::sync::Mutex::new(bot_turns::BotTurnTracker::new( + discord_cfg.max_bot_turns, + )), allow_dm: discord_cfg.allow_dm, dispatcher: discord_dispatcher, }; @@ -503,7 +548,8 @@ mod tests { #[test] fn cli_run_with_remote_url() { - let cli = Cli::try_parse_from(["openab", "run", "-c", "https://example.com/config.toml"]).unwrap(); + let cli = Cli::try_parse_from(["openab", "run", "-c", "https://example.com/config.toml"]) + .unwrap(); match cli.command.unwrap() { Commands::Run { config } => assert!(config.unwrap().starts_with("https://")), _ => panic!("expected Run"), diff --git a/src/markdown.rs b/src/markdown.rs index 6b0aa533..32398cc2 100644 --- a/src/markdown.rs +++ b/src/markdown.rs @@ -330,9 +330,7 @@ Some text after. // The table is inside a ``` block — backtick wrapping must be stripped. assert!(result.contains("value"), "cell content should be present"); // Only the fence markers themselves should contain backticks. - let inner = result - .trim_start_matches("```\n") - .trim_end_matches("```\n"); + let inner = result.trim_start_matches("```\n").trim_end_matches("```\n"); assert!( !inner.contains('`'), "no backticks should appear inside the code fence: {result:?}" @@ -343,6 +341,9 @@ Some text after. fn bullets_mode_keeps_backticks_in_code_cells() { let md = "| col |\n|-----|\n| `value` |\n"; let result = convert_tables(md, TableMode::Bullets); - assert!(result.contains("`value`"), "backticks should be kept in bullets mode"); + assert!( + result.contains("`value`"), + "backticks should be kept in bullets mode" + ); } } diff --git a/src/media.rs b/src/media.rs index 5e0c057f..aa56e5f4 100644 --- a/src/media.rs +++ b/src/media.rs @@ -71,7 +71,10 @@ pub async fn download_and_encode_image( let response = match req.send().await { Ok(resp) => resp, - Err(e) => { error!(url, error = %e, "download failed"); return None; } + Err(e) => { + error!(url, error = %e, "download failed"); + return None; + } }; if !response.status().is_success() { error!(url, status = %response.status(), "HTTP error downloading image"); @@ -79,11 +82,18 @@ pub async fn download_and_encode_image( } let bytes = match response.bytes().await { Ok(b) => b, - Err(e) => { error!(url, error = %e, "read failed"); return None; } + Err(e) => { + error!(url, error = %e, "read failed"); + return None; + } }; if bytes.len() as u64 > MAX_SIZE { - error!(filename, size = bytes.len(), "downloaded image exceeds limit"); + error!( + filename, + size = bytes.len(), + "downloaded image exceeds limit" + ); return None; } @@ -142,14 +152,20 @@ pub async fn download_and_transcribe( } let bytes = resp.bytes().await.ok()?.to_vec(); - crate::stt::transcribe(&HTTP_CLIENT, stt_config, bytes, filename.to_string(), mime_type).await + crate::stt::transcribe( + &HTTP_CLIENT, + stt_config, + bytes, + filename.to_string(), + mime_type, + ) + .await } /// Resize image so longest side <= IMAGE_MAX_DIMENSION_PX, then encode as JPEG. /// GIFs are passed through unchanged to preserve animation. pub fn resize_and_compress(raw: &[u8]) -> Result<(Vec, String), image::ImageError> { - let reader = ImageReader::new(Cursor::new(raw)) - .with_guessed_format()?; + let reader = ImageReader::new(Cursor::new(raw)).with_guessed_format()?; let format = reader.format(); @@ -184,16 +200,23 @@ pub fn is_audio_mime(mime: &str) -> bool { /// Extensions recognised as text-based files that can be inlined into the prompt. const TEXT_EXTENSIONS: &[&str] = &[ - "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", - "rs", "py", "js", "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", - "rb", "sh", "bash", "zsh", "fish", "ps1", "bat", "sql", "html", "css", - "scss", "less", "ini", "cfg", "conf", "env", + "txt", "csv", "log", "md", "json", "jsonl", "yaml", "yml", "toml", "xml", "rs", "py", "js", + "ts", "jsx", "tsx", "go", "java", "c", "cpp", "h", "hpp", "rb", "sh", "bash", "zsh", "fish", + "ps1", "bat", "sql", "html", "css", "scss", "less", "ini", "cfg", "conf", "env", ]; /// Exact filenames (no extension) recognised as text files. const TEXT_FILENAMES: &[&str] = &[ - "dockerfile", "makefile", "justfile", "rakefile", "gemfile", - "procfile", "vagrantfile", ".gitignore", ".dockerignore", ".editorconfig", + "dockerfile", + "makefile", + "justfile", + "rakefile", + "gemfile", + "procfile", + "vagrantfile", + ".gitignore", + ".dockerignore", + ".editorconfig", ]; /// MIME types recognised as text-based (beyond `text/*`). @@ -268,7 +291,11 @@ pub async fn download_and_read_text_file( // Defense-in-depth: verify actual download size if actual_size > MAX_SIZE { - tracing::warn!(filename, size = actual_size, "downloaded text file exceeds 512KB limit, skipping"); + tracing::warn!( + filename, + size = actual_size, + "downloaded text file exceeds 512KB limit, skipping" + ); return None; } @@ -348,16 +375,19 @@ mod tests { let png = make_png(3000, 2000); let (compressed, _) = resize_and_compress(&png).unwrap(); - assert!(compressed.len() < png.len(), "compressed {} should be < original {}", compressed.len(), png.len()); + assert!( + compressed.len() < png.len(), + "compressed {} should be < original {}", + compressed.len(), + png.len() + ); } #[test] fn gif_passes_through_unchanged() { let gif: Vec = vec![ - 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, - 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x2C, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, - 0x02, 0x02, 0x44, 0x01, 0x00, + 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x2C, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x02, 0x02, 0x44, 0x01, 0x00, 0x3B, ]; let (output, mime) = resize_and_compress(&gif).unwrap(); diff --git a/src/reactions.rs b/src/reactions.rs index 8638d86f..6e68f90b 100644 --- a/src/reactions.rs +++ b/src/reactions.rs @@ -5,7 +5,13 @@ use tokio::sync::Mutex; use tokio::time::Duration; const CODING_TOKENS: &[&str] = &["exec", "process", "read", "write", "edit", "bash", "shell"]; -const WEB_TOKENS: &[&str] = &["web_search", "web_fetch", "web-search", "web-fetch", "browser"]; +const WEB_TOKENS: &[&str] = &[ + "web_search", + "web_fetch", + "web-search", + "web-fetch", + "browser", +]; fn classify_tool<'a>(name: &str, emojis: &'a ReactionEmojis) -> &'a str { let n = name.to_lowercase(); @@ -60,19 +66,25 @@ impl StatusReactionController { } pub async fn set_queued(&self) { - if !self.enabled { return; } + if !self.enabled { + return; + } let emoji = { self.inner.lock().await.emojis.queued.clone() }; self.apply_immediate(&emoji).await; } pub async fn set_thinking(&self) { - if !self.enabled { return; } + if !self.enabled { + return; + } let emoji = { self.inner.lock().await.emojis.thinking.clone() }; self.schedule_debounced(&emoji).await; } pub async fn set_tool(&self, tool_name: &str) { - if !self.enabled { return; } + if !self.enabled { + return; + } let emoji = { let inner = self.inner.lock().await; classify_tool(tool_name, &inner.emojis).to_string() @@ -81,7 +93,9 @@ impl StatusReactionController { } pub async fn set_done(&self) { - if !self.enabled { return; } + if !self.enabled { + return; + } let emoji = { self.inner.lock().await.emojis.done.clone() }; self.finish(&emoji).await; // Add a random mood face @@ -92,18 +106,25 @@ impl StatusReactionController { } pub async fn set_error(&self) { - if !self.enabled { return; } + if !self.enabled { + return; + } let emoji = { self.inner.lock().await.emojis.error.clone() }; self.finish(&emoji).await; } pub async fn clear(&self) { - if !self.enabled { return; } + if !self.enabled { + return; + } let mut inner = self.inner.lock().await; cancel_timers(&mut inner); let current = inner.current.clone(); if !current.is_empty() { - let _ = inner.adapter.remove_reaction(&inner.message, ¤t).await; + let _ = inner + .adapter + .remove_reaction(&inner.message, ¤t) + .await; inner.current.clear(); } } @@ -142,7 +163,9 @@ impl StatusReactionController { inner.debounce_handle = Some(tokio::spawn(async move { tokio::time::sleep(Duration::from_millis(debounce_ms)).await; let mut inner = ctrl.lock().await; - if inner.finished { return; } + if inner.finished { + return; + } let old = inner.current.clone(); inner.current = emoji.clone(); let adapter = inner.adapter.clone(); @@ -159,7 +182,9 @@ impl StatusReactionController { async fn finish(&self, emoji: &str) { let mut inner = self.inner.lock().await; - if inner.finished { return; } + if inner.finished { + return; + } inner.finished = true; cancel_timers(&mut inner); @@ -182,8 +207,12 @@ impl StatusReactionController { } fn reset_stall_timers_inner(&self, inner: &mut Inner) { - if let Some(h) = inner.stall_soft_handle.take() { h.abort(); } - if let Some(h) = inner.stall_hard_handle.take() { h.abort(); } + if let Some(h) = inner.stall_soft_handle.take() { + h.abort(); + } + if let Some(h) = inner.stall_hard_handle.take() { + h.abort(); + } let soft_ms = inner.timing.stall_soft_ms; let hard_ms = inner.timing.stall_hard_ms; @@ -194,7 +223,9 @@ impl StatusReactionController { async move { tokio::time::sleep(Duration::from_millis(soft_ms)).await; let mut inner = ctrl.lock().await; - if inner.finished { return; } + if inner.finished { + return; + } let old = inner.current.clone(); inner.current = "🥱".to_string(); let adapter = inner.adapter.clone(); @@ -210,7 +241,9 @@ impl StatusReactionController { inner.stall_hard_handle = Some(tokio::spawn(async move { tokio::time::sleep(Duration::from_millis(hard_ms)).await; let mut inner = ctrl.lock().await; - if inner.finished { return; } + if inner.finished { + return; + } let old = inner.current.clone(); inner.current = "😨".to_string(); let adapter = inner.adapter.clone(); @@ -225,11 +258,19 @@ impl StatusReactionController { } fn cancel_debounce(inner: &mut Inner) { - if let Some(h) = inner.debounce_handle.take() { h.abort(); } + if let Some(h) = inner.debounce_handle.take() { + h.abort(); + } } fn cancel_timers(inner: &mut Inner) { - if let Some(h) = inner.debounce_handle.take() { h.abort(); } - if let Some(h) = inner.stall_soft_handle.take() { h.abort(); } - if let Some(h) = inner.stall_hard_handle.take() { h.abort(); } + if let Some(h) = inner.debounce_handle.take() { + h.abort(); + } + if let Some(h) = inner.stall_soft_handle.take() { + h.abort(); + } + if let Some(h) = inner.stall_hard_handle.take() { + h.abort(); + } } diff --git a/src/setup/config.rs b/src/setup/config.rs index 21d65e7e..c0e7d604 100644 --- a/src/setup/config.rs +++ b/src/setup/config.rs @@ -85,10 +85,7 @@ pub fn generate_config( }, agent: { let (command, args): (&str, Vec) = match agent_command { - "kiro" => ( - "kiro-cli", - vec!["acp".into(), "--trust-all-tools".into()], - ), + "kiro" => ("kiro-cli", vec!["acp".into(), "--trust-all-tools".into()]), "claude" => ("claude-agent-acp", vec![]), "codex" => ("codex-acp", vec![]), "gemini" => ("gemini", vec!["--acp".into()]), @@ -152,14 +149,7 @@ mod tests { #[test] fn test_generate_config_kiro_working_dir() { - let config = generate_config( - "tok", - "kiro", - vec!["ch".to_string()], - "/home/agent", - 10, - 24, - ); + let config = generate_config("tok", "kiro", vec!["ch".to_string()], "/home/agent", 10, 24); assert!(config.contains(r#"working_dir = "/home/agent""#)); assert!(config.contains("acp")); assert!(config.contains("--trust-all-tools")); diff --git a/src/setup/validate.rs b/src/setup/validate.rs index 247b1b9a..527a1a38 100644 --- a/src/setup/validate.rs +++ b/src/setup/validate.rs @@ -5,10 +5,15 @@ pub fn validate_bot_token(token: &str) -> anyhow::Result<()> { if token.is_empty() { anyhow::bail!("Token cannot be empty"); } - if !token - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '/' || c == '*' || c == '=') - { + if !token.chars().all(|c| { + c.is_ascii_alphanumeric() + || c == '-' + || c == '.' + || c == '_' + || c == '/' + || c == '*' + || c == '=' + }) { anyhow::bail!( "Token must only contain ASCII letters, numbers, dash, period, underscore, slash, or equals" ); diff --git a/src/setup/wizard.rs b/src/setup/wizard.rs index e8751172..f5a78960 100644 --- a/src/setup/wizard.rs +++ b/src/setup/wizard.rs @@ -154,7 +154,11 @@ fn print_box(lines: &[&str]) { .unwrap_or(60); let width = width.clamp(60, 76); println!(); - cprintln!(C.cyan, "{}", "╔".to_string() + &BORDER.to_string().repeat(width + 2) + "╗"); + cprintln!( + C.cyan, + "{}", + "╔".to_string() + &BORDER.to_string().repeat(width + 2) + "╗" + ); for line in lines { let padded = format!(" {: anyhow::Result> { println!(); if guilds.is_empty() { - cprintln!( - C.yellow, - " No servers found. Enter channel IDs manually." - ); + cprintln!(C.yellow, " No servers found. Enter channel IDs manually."); let input = prompt(" Channel ID(s), comma-separated"); let ids: Vec = input .split(',') @@ -342,21 +347,11 @@ fn section_channels(client: &DiscordClient) -> anyhow::Result> { return Ok(ids); } - let channel_names: Vec = channels - .iter() - .map(|(_, n, _)| format!("#{}", n)) - .collect(); - let channel_names_refs: Vec<&str> = channel_names - .iter() - .map(|s| s.as_str()) - .collect(); + let channel_names: Vec = channels.iter().map(|(_, n, _)| format!("#{}", n)).collect(); + let channel_names_refs: Vec<&str> = channel_names.iter().map(|s| s.as_str()).collect(); - let selected = - prompt_checklist(" Select channels (by number):", &channel_names_refs); - let selected_ids: Vec = selected - .iter() - .map(|&i| channels[i].0.clone()) - .collect(); + let selected = prompt_checklist(" Select channels (by number):", &channel_names_refs); + let selected_ids: Vec = selected.iter().map(|&i| channels[i].0.clone()).collect(); println!(); cprintln!(C.green, " Selected {} channel(s)", selected_ids.len()); @@ -408,12 +403,7 @@ fn section_agent() -> (String, String, bool) { let working_dir = prompt_default(" Working directory", default_dir); - cprintln!( - C.green, - " Agent: {} | Working dir: {}", - agent, - working_dir - ); + cprintln!(C.green, " Agent: {} | Working dir: {}", agent, working_dir); println!(); (agent.to_string(), working_dir, is_local) @@ -428,9 +418,7 @@ fn section_pool() -> (usize, u64) { cprintln!(C.bold, "--- Step 4: Session Pool ---"); println!(); - let max_sessions: usize = prompt_default(" Max sessions", "10") - .parse() - .unwrap_or(10); + let max_sessions: usize = prompt_default(" Max sessions", "10").parse().unwrap_or(10); let ttl_hours: u64 = prompt_default(" Session TTL (hours)", "24") .parse() .unwrap_or(24); @@ -457,9 +445,7 @@ fn section_preview_and_save(config_content: &str, output_path: &PathBuf) -> anyh println!("{}", mask_bot_token(config_content)); println!(); - if output_path.exists() - && !prompt_yes_no(" File exists. Overwrite?", false) - { + if output_path.exists() && !prompt_yes_no(" File exists. Overwrite?", false) { println!(" Saving cancelled."); return Ok(()); } @@ -517,7 +503,10 @@ fn print_next_steps(agent: &str, output_path: &Path, is_local: bool) { if is_local { match agent { "kiro" => { - cprintln!(C.cyan, " 1. Install kiro-cli (see https://kiro.dev for installer)"); + cprintln!( + C.cyan, + " 1. Install kiro-cli (see https://kiro.dev for installer)" + ); cprintln!(C.cyan, " 2. Authenticate:"); println!(" kiro-cli login --use-device-flow"); } @@ -536,7 +525,10 @@ fn print_next_steps(agent: &str, output_path: &Path, is_local: bool) { "gemini" => { cprintln!(C.cyan, " 1. Install Gemini CLI:"); println!(" npm install -g @google/gemini-cli"); - cprintln!(C.cyan, " 2. Authenticate via Google OAuth, or set GEMINI_API_KEY in config.toml"); + cprintln!( + C.cyan, + " 2. Authenticate via Google OAuth, or set GEMINI_API_KEY in config.toml" + ); } _ => {} } @@ -552,22 +544,28 @@ fn print_next_steps(agent: &str, output_path: &Path, is_local: bool) { println!(); cprintln!(C.cyan, " 1. Deploy with Helm (or your preferred method):"); println!(" helm install openab openab/openab \\"); - println!(" --set agents.{}.discord.botToken=\"$BOT_TOKEN\"", agent); + println!( + " --set agents.{}.discord.botToken=\"$BOT_TOKEN\"", + agent + ); println!(); - cprintln!(C.cyan, " 2. Authenticate inside the pod (first time only):"); + cprintln!( + C.cyan, + " 2. Authenticate inside the pod (first time only):" + ); match agent { "kiro" => println!( " kubectl exec -it deployment/openab-kiro -- kiro-cli login --use-device-flow" ), - "claude" => println!( - " kubectl exec -it deployment/openab-claude -- claude auth login" - ), + "claude" => { + println!(" kubectl exec -it deployment/openab-claude -- claude auth login") + } "codex" => println!( " kubectl exec -it deployment/openab-codex -- codex login --device-auth" ), - "gemini" => println!( - " Set GEMINI_API_KEY via secret, or exec into the pod for OAuth" - ), + "gemini" => { + println!(" Set GEMINI_API_KEY via secret, or exec into the pod for OAuth") + } _ => {} } println!(); @@ -605,10 +603,7 @@ pub fn run_setup(output_path: Option) -> anyhow::Result<()> { println!(); let bot_token = prompt_password(" Bot Token (or press Enter to skip)"); if bot_token.is_empty() { - cprintln!( - C.yellow, - " Skipped. Set bot_token manually in config.toml" - ); + cprintln!(C.yellow, " Skipped. Set bot_token manually in config.toml"); println!(); cprintln!( C.green, @@ -632,11 +627,7 @@ pub fn run_setup(output_path: Option) -> anyhow::Result<()> { vec![] } Err(e) => { - cprintln!( - C.yellow, - " Channel fetch failed: {}. Enter manually.", - e - ); + cprintln!(C.yellow, " Channel fetch failed: {}. Enter manually.", e); let input = prompt(" Channel ID(s), comma-separated"); let ids: Vec = input .split(',') diff --git a/src/slack.rs b/src/slack.rs index 979db52b..74d46062 100644 --- a/src/slack.rs +++ b/src/slack.rs @@ -1,5 +1,5 @@ use crate::acp::ContentBlock; -use crate::adapter::{ChatAdapter, ChannelRef, MessageRef, SenderContext}; +use crate::adapter::{ChannelRef, ChatAdapter, MessageRef, SenderContext}; use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; use crate::config::{AllowBots, AllowUsers, SttConfig}; use crate::media; @@ -70,7 +70,11 @@ pub struct SlackAdapter { } impl SlackAdapter { - pub fn new(bot_token: String, session_ttl: std::time::Duration, _allow_bot_messages: AllowBots) -> Self { + pub fn new( + bot_token: String, + session_ttl: std::time::Duration, + _allow_bot_messages: AllowBots, + ) -> Self { Self { client: reqwest::Client::new(), bot_token, @@ -93,20 +97,28 @@ impl SlackAdapter { /// depend on fetching thread history. Idempotent. async fn note_other_bot_in_thread(&self, thread_ts: &str) { let mut cache = self.multibot_threads.lock().await; - cache.entry(thread_ts.to_string()).or_insert_with(tokio::time::Instant::now); + cache + .entry(thread_ts.to_string()) + .or_insert_with(tokio::time::Instant::now); enforce_cache_bounds(&mut cache, self.session_ttl); } /// Get the bot's own Slack user ID (cached after first call). async fn get_bot_user_id(&self) -> Option<&str> { - self.bot_user_id.get_or_try_init(|| async { - let resp = self.api_post("auth.test", serde_json::json!({})).await - .map_err(|e| anyhow!("auth.test failed: {e}"))?; - resp["user_id"] - .as_str() - .map(|s| s.to_string()) - .ok_or_else(|| anyhow!("no user_id in auth.test response")) - }).await.ok().map(|s| s.as_str()) + self.bot_user_id + .get_or_try_init(|| async { + let resp = self + .api_post("auth.test", serde_json::json!({})) + .await + .map_err(|e| anyhow!("auth.test failed: {e}"))?; + resp["user_id"] + .as_str() + .map(|s| s.to_string()) + .ok_or_else(|| anyhow!("no user_id in auth.test response")) + }) + .await + .ok() + .map(|s| s.as_str()) } async fn api_post(&self, method: &str, body: serde_json::Value) -> Result { @@ -160,10 +172,7 @@ impl SlackAdapter { } let resp = self - .api_post( - "users.info", - serde_json::json!({ "user": user_id }), - ) + .api_post("users.info", serde_json::json!({ "user": user_id })) .await .ok()?; let user = resp.get("user")?; @@ -176,9 +185,7 @@ impl SlackAdapter { .get("real_name") .and_then(|v| v.as_str()) .filter(|s| !s.is_empty()); - let name = user - .get("name") - .and_then(|v| v.as_str()); + let name = user.get("name").and_then(|v| v.as_str()); let resolved = display.or(real).or(name)?.to_string(); // Cache the result @@ -204,15 +211,12 @@ impl SlackAdapter { .api_post("bots.info", serde_json::json!({ "bot": bot_id })) .await .ok()?; - let user_id = resp.get("bot")? - .get("user_id")? - .as_str()? - .to_string(); - - self.bot_id_cache.lock().await.insert( - bot_id.to_string(), - user_id.clone(), - ); + let user_id = resp.get("bot")?.get("user_id")?.as_str()?.to_string(); + + self.bot_id_cache + .lock() + .await + .insert(bot_id.to_string(), user_id.clone()); Some(user_id) } @@ -226,11 +230,15 @@ impl SlackAdapter { async fn bot_participated_in_thread(&self, channel: &str, thread_ts: &str) -> (bool, bool) { let cached_involved = { let cache = self.participated_threads.lock().await; - cache.get(thread_ts).is_some_and(|ts| ts.elapsed() < self.session_ttl) + cache + .get(thread_ts) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) }; let cached_multibot = { let cache = self.multibot_threads.lock().await; - cache.get(thread_ts).is_some_and(|ts| ts.elapsed() < self.session_ttl) + cache + .get(thread_ts) + .is_some_and(|ts| ts.elapsed() < self.session_ttl) }; // Eager multibot detection from message events populates the cache @@ -266,7 +274,9 @@ impl SlackAdapter { return (false, false); } }; - let Some(messages) = json["messages"].as_array() else { return (false, false) }; + let Some(messages) = json["messages"].as_array() else { + return (false, false); + }; let parent_mentions_bot = messages .first() @@ -278,8 +288,8 @@ impl SlackAdapter { let involved = parent_mentions_bot || bot_posted; let other_bot_present = cached_multibot || messages.iter().any(|m| { - let is_bot_msg = m["bot_id"].is_string() - || m["subtype"].as_str() == Some("bot_message"); + let is_bot_msg = + m["bot_id"].is_string() || m["subtype"].as_str() == Some("bot_message"); is_bot_msg && m["user"].as_str() != Some(bot_id) }); @@ -356,7 +366,6 @@ impl ChatAdapter for SlackAdapter { }) } - async fn create_thread( &self, channel: &ChannelRef, @@ -375,15 +384,16 @@ impl ChatAdapter for SlackAdapter { async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { let name = unicode_to_slack_emoji(emoji); - match self.api_post( - "reactions.add", - serde_json::json!({ - "channel": msg.channel.channel_id, - "timestamp": msg.message_id, - "name": name, - }), - ) - .await + match self + .api_post( + "reactions.add", + serde_json::json!({ + "channel": msg.channel.channel_id, + "timestamp": msg.message_id, + "name": name, + }), + ) + .await { Ok(_) => Ok(()), Err(e) if e.to_string().contains("already_reacted") => Ok(()), @@ -393,15 +403,16 @@ impl ChatAdapter for SlackAdapter { async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { let name = unicode_to_slack_emoji(emoji); - match self.api_post( - "reactions.remove", - serde_json::json!({ - "channel": msg.channel.channel_id, - "timestamp": msg.message_id, - "name": name, - }), - ) - .await + match self + .api_post( + "reactions.remove", + serde_json::json!({ + "channel": msg.channel.channel_id, + "timestamp": msg.message_id, + "name": name, + }), + ) + .await { Ok(_) => Ok(()), Err(e) if e.to_string().contains("no_reaction") => Ok(()), @@ -867,8 +878,8 @@ async fn handle_message( Some(u) => u.to_string(), None => return, }; - let is_bot_msg = event["bot_id"].is_string() - || event["subtype"].as_str() == Some("bot_message"); + let is_bot_msg = + event["bot_id"].is_string() || event["subtype"].as_str() == Some("bot_message"); let text = match event["text"].as_str() { Some(t) => t.to_string(), None => return, @@ -920,6 +931,7 @@ async fn handle_message( const TEXT_FILE_COUNT_CAP: u32 = 5; let mut extra_blocks = Vec::new(); + let mut echo_entries: Vec = Vec::new(); let mut text_file_bytes: u64 = 0; let mut text_file_count: u32 = 0; @@ -938,18 +950,34 @@ async fn handle_message( if media::is_audio_mime(mimetype) { if stt_config.enabled { - if let Some(transcript) = media::download_and_transcribe( + match media::download_and_transcribe( url, filename, mimetype, size, stt_config, Some(bot_token), - ).await { - debug!(filename, chars = transcript.len(), "voice transcript injected"); - extra_blocks.insert(0, ContentBlock::Text { - text: format!("[Voice message transcript]: {transcript}"), - }); + ) + .await + { + Some(transcript) => { + debug!( + filename, + chars = transcript.len(), + "voice transcript injected" + ); + extra_blocks.insert( + 0, + ContentBlock::Text { + text: format!("[Voice message transcript]: {transcript}"), + }, + ); + echo_entries.push(crate::stt::EchoEntry::Success(transcript)); + } + None => { + warn!(filename, "STT failed for voice attachment"); + echo_entries.push(crate::stt::EchoEntry::Failed); + } } } else { debug!(filename, "skipping audio attachment (STT disabled)"); @@ -967,7 +995,11 @@ async fn handle_message( } } else if media::is_text_file(filename, Some(mimetype)) { if text_file_count >= TEXT_FILE_COUNT_CAP { - debug!(filename, count = text_file_count, "text file count cap reached, skipping"); + debug!( + filename, + count = text_file_count, + "text file count cap reached, skipping" + ); continue; } // Pre-check with Slack-reported size as a fast path when the @@ -976,15 +1008,16 @@ async fn handle_message( // authoritative cap check happens after download using // `actual_bytes`. if size > 0 && text_file_bytes + size > TEXT_TOTAL_CAP { - debug!(filename, total = text_file_bytes, "text attachments total exceeds 1MB cap, skipping remaining"); + debug!( + filename, + total = text_file_bytes, + "text attachments total exceeds 1MB cap, skipping remaining" + ); continue; } - if let Some((block, actual_bytes)) = media::download_and_read_text_file( - url, - filename, - size, - Some(bot_token), - ).await { + if let Some((block, actual_bytes)) = + media::download_and_read_text_file(url, filename, size, Some(bot_token)).await + { if text_file_bytes + actual_bytes > TEXT_TOTAL_CAP { debug!( filename, @@ -1005,7 +1038,9 @@ async fn handle_message( filename, size, Some(bot_token), - ).await { + ) + .await + { debug!(filename, "adding image attachment"); extra_blocks.push(block); } @@ -1065,9 +1100,23 @@ async fn handle_message( let adapter_dyn: Arc = adapter.clone(); let other_bot_present = { let cache = adapter.multibot_threads.lock().await; - thread_channel.thread_id.as_deref() - .is_some_and(|ts| cache.get(ts).is_some_and(|inst| inst.elapsed() < adapter.session_ttl)) + thread_channel.thread_id.as_deref().is_some_and(|ts| { + cache + .get(ts) + .is_some_and(|inst| inst.elapsed() < adapter.session_ttl) + }) }; + + // Best-effort echo before the agent reply so the user can verify STT. + crate::stt::post_echo( + &adapter_dyn, + &thread_channel, + &trigger_msg, + &echo_entries, + stt_config, + ) + .await; + let thread_id = thread_channel .thread_id .as_deref() @@ -1157,12 +1206,12 @@ fn markdown_to_mrkdwn(text: &str) -> String { LazyLock::new(|| regex::Regex::new(r"```\w+\n").unwrap()); // Order: bold first (** → placeholder), then italic (* → _), then restore bold - let text = BOLD_RE.replace_all(text, "\x01$1\x02"); // **bold** → \x01bold\x02 - let text = ITALIC_RE.replace_all(&text, "_${1}_"); // *italic* → _italic_ - // Restore bold: \x01bold\x02 → *bold* + let text = BOLD_RE.replace_all(text, "\x01$1\x02"); // **bold** → \x01bold\x02 + let text = ITALIC_RE.replace_all(&text, "_${1}_"); // *italic* → _italic_ + // Restore bold: \x01bold\x02 → *bold* let text = text.replace(['\x01', '\x02'], "*"); - let text = LINK_RE.replace_all(&text, "<$2|$1>"); // [text](url) → - let text = HEADING_RE.replace_all(&text, "*$1*"); // # heading → *heading* + let text = LINK_RE.replace_all(&text, "<$2|$1>"); // [text](url) → + let text = HEADING_RE.replace_all(&text, "*$1*"); // # heading → *heading* let text = CODE_BLOCK_LANG_RE.replace_all(&text, "```\n"); // ```rust → ``` text.into_owned() } @@ -1319,7 +1368,13 @@ mod tests { let ttl = std::time::Duration::from_secs(300); let adapter = SlackAdapter::new("xoxb-test".into(), ttl, AllowBots::Mentions); - assert!(adapter.use_streaming(false), "should stream when no other bot"); - assert!(!adapter.use_streaming(true), "should NOT stream when other bot present"); + assert!( + adapter.use_streaming(false), + "should stream when no other bot" + ); + assert!( + !adapter.use_streaming(true), + "should NOT stream when other bot present" + ); } } diff --git a/src/stt.rs b/src/stt.rs index 122db9b6..d266e611 100644 --- a/src/stt.rs +++ b/src/stt.rs @@ -1,6 +1,74 @@ +use crate::adapter::{ChannelRef, ChatAdapter, MessageRef}; use crate::config::SttConfig; use reqwest::multipart; -use tracing::{debug, error}; +use std::sync::Arc; +use tracing::{debug, error, warn}; + +/// Outcome of attempting STT on a single audio attachment. +/// Used by adapters to feed `post_echo`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EchoEntry { + Success(String), + Failed, +} + +/// Render a list of echo entries as a single multi-line quoted block. +/// Returns `None` for empty input so callers can short-circuit. +/// +/// Each entry produces one `> 🎤 …` line. Internal newlines inside a +/// transcript are flattened to spaces so each entry occupies exactly one +/// visual line — Discord and Slack both stop applying `>` at the next `\n`. +pub fn format_echo_message(entries: &[EchoEntry]) -> Option { + if entries.is_empty() { + return None; + } + let mut lines = Vec::with_capacity(entries.len()); + for e in entries { + match e { + EchoEntry::Success(text) => { + let flat = text.replace(['\n', '\r'], " "); + lines.push(format!("> 🎤 {flat}")); + } + EchoEntry::Failed => { + lines.push("> 🎤 (transcription failed)".to_string()); + } + } + } + Some(lines.join("\n")) +} + +/// Post a transcript echo to the thread and add a ⚠️ reaction for any failed +/// entries. No-op when the config disables echoing or when `entries` is empty. +/// +/// Errors from the adapter (send/reaction) are logged and swallowed — the +/// echo is best-effort and must never block the agent reply. +pub async fn post_echo( + adapter: &Arc, + thread: &ChannelRef, + trigger: &MessageRef, + entries: &[EchoEntry], + cfg: &SttConfig, +) { + if !cfg.echo_transcript { + return; + } + let Some(body) = format_echo_message(entries) else { + return; + }; + if let Err(e) = adapter.send_message(thread, &body).await { + warn!(error = %e, platform = adapter.platform(), "failed to send STT echo message"); + } + for entry in entries { + if matches!(entry, EchoEntry::Failed) { + if let Err(e) = adapter.add_reaction(trigger, "⚠️").await { + warn!(error = %e, platform = adapter.platform(), "failed to add STT failure reaction"); + } + // Add only one reaction even with multiple failures — emoji reactions + // are unique per (user, emoji, message), so additional calls are no-ops. + break; + } + } +} /// Transcribe audio bytes via an OpenAI-compatible `/audio/transcriptions` endpoint. pub async fn transcribe( @@ -10,7 +78,10 @@ pub async fn transcribe( filename: String, mime_type: &str, ) -> Option { - let url = format!("{}/audio/transcriptions", cfg.base_url.trim_end_matches('/')); + let url = format!( + "{}/audio/transcriptions", + cfg.base_url.trim_end_matches('/') + ); let file_part = multipart::Part::bytes(audio_bytes) .file_name(filename) @@ -59,3 +130,225 @@ pub async fn transcribe( debug!(chars = text.len(), "STT transcription complete"); Some(text) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn format_single_success_entry() { + let entries = vec![EchoEntry::Success("hello world".into())]; + let out = format_echo_message(&entries).expect("non-empty input → Some"); + assert_eq!(out, "> 🎤 hello world"); + } + + #[test] + fn format_single_failure_entry() { + let entries = vec![EchoEntry::Failed]; + let out = format_echo_message(&entries).expect("non-empty input → Some"); + assert_eq!(out, "> 🎤 (transcription failed)"); + } + + #[test] + fn format_multiple_mixed_entries() { + let entries = vec![ + EchoEntry::Success("first".into()), + EchoEntry::Failed, + EchoEntry::Success("third".into()), + ]; + let out = format_echo_message(&entries).expect("non-empty input → Some"); + assert_eq!(out, "> 🎤 first\n> 🎤 (transcription failed)\n> 🎤 third"); + } + + #[test] + fn format_empty_entries_returns_none() { + let entries: Vec = vec![]; + assert!(format_echo_message(&entries).is_none()); + } + + #[test] + fn format_strips_internal_newlines_in_transcript() { + // Multi-line transcripts must collapse to a single quoted line so the + // ">" prefix still applies to every visual line. + let entries = vec![EchoEntry::Success("line one\nline two".into())]; + let out = format_echo_message(&entries).expect("non-empty input → Some"); + assert_eq!(out, "> 🎤 line one line two"); + } + + use crate::adapter::{ChannelRef, ChatAdapter, MessageRef}; + use anyhow::Result; + use async_trait::async_trait; + use std::sync::{Arc, Mutex}; + + #[derive(Default)] + struct MockAdapter { + sent_messages: Mutex>, + reactions: Mutex>, + } + + #[async_trait] + impl ChatAdapter for MockAdapter { + fn platform(&self) -> &'static str { + "mock" + } + fn message_limit(&self) -> usize { + 4000 + } + async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result { + self.sent_messages + .lock() + .unwrap() + .push((channel.clone(), content.to_string())); + Ok(MessageRef { + channel: channel.clone(), + message_id: "mock-msg".into(), + }) + } + async fn create_thread( + &self, + channel: &ChannelRef, + _trigger: &MessageRef, + _title: &str, + ) -> Result { + Ok(channel.clone()) + } + async fn add_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()> { + self.reactions + .lock() + .unwrap() + .push((msg.clone(), emoji.to_string())); + Ok(()) + } + async fn remove_reaction(&self, _msg: &MessageRef, _emoji: &str) -> Result<()> { + Ok(()) + } + fn use_streaming(&self, _other_bot_present: bool) -> bool { + false + } + } + + fn test_channel() -> ChannelRef { + ChannelRef { + platform: "mock".into(), + channel_id: "C1".into(), + thread_id: Some("T1".into()), + parent_id: None, + origin_event_id: None, + } + } + + fn test_trigger() -> MessageRef { + MessageRef { + channel: test_channel(), + message_id: "M1".into(), + } + } + + fn cfg(echo: bool) -> SttConfig { + SttConfig { + echo_transcript: echo, + ..SttConfig::default() + } + } + + #[tokio::test] + async fn post_echo_success_sends_one_message_no_reactions() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Success("hello".into())]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); + assert_eq!(mock.sent_messages.lock().unwrap()[0].1, "> 🎤 hello"); + assert!(mock.reactions.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn post_echo_failure_adds_warning_reaction() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Failed]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); + assert_eq!( + mock.sent_messages.lock().unwrap()[0].1, + "> 🎤 (transcription failed)" + ); + let reactions = mock.reactions.lock().unwrap(); + assert_eq!(reactions.len(), 1); + assert_eq!(reactions[0].1, "⚠️"); + } + + #[tokio::test] + async fn post_echo_mixed_one_message_one_reaction() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Success("ok".into()), EchoEntry::Failed]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert_eq!(mock.sent_messages.lock().unwrap().len(), 1); + assert_eq!( + mock.sent_messages.lock().unwrap()[0].1, + "> 🎤 ok\n> 🎤 (transcription failed)" + ); + assert_eq!(mock.reactions.lock().unwrap().len(), 1); + } + + #[tokio::test] + async fn post_echo_disabled_is_noop() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries = vec![EchoEntry::Success("hi".into()), EchoEntry::Failed]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(false), + ) + .await; + + assert!(mock.sent_messages.lock().unwrap().is_empty()); + assert!(mock.reactions.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn post_echo_empty_entries_is_noop() { + let mock = Arc::new(MockAdapter::default()); + let adapter: Arc = mock.clone(); + let entries: Vec = vec![]; + post_echo( + &adapter, + &test_channel(), + &test_trigger(), + &entries, + &cfg(true), + ) + .await; + + assert!(mock.sent_messages.lock().unwrap().is_empty()); + assert!(mock.reactions.lock().unwrap().is_empty()); + } +} diff --git a/src/timestamp.rs b/src/timestamp.rs index e6c8d49f..aa7adce4 100644 --- a/src/timestamp.rs +++ b/src/timestamp.rs @@ -64,24 +64,36 @@ mod tests { #[test] fn slack_ts_keeps_milliseconds() { // 1714204397 = 2024-04-27T07:53:17 UTC; .123456 → .123 ms - assert_eq!(slack_ts_to_iso8601("1714204397.123456"), "2024-04-27T07:53:17.123Z"); + assert_eq!( + slack_ts_to_iso8601("1714204397.123456"), + "2024-04-27T07:53:17.123Z" + ); } #[test] fn slack_ts_missing_fraction_uses_zero() { - assert_eq!(slack_ts_to_iso8601("1714204397"), "2024-04-27T07:53:17.000Z"); + assert_eq!( + slack_ts_to_iso8601("1714204397"), + "2024-04-27T07:53:17.000Z" + ); } #[test] fn slack_ts_two_digit_fraction_is_120ms_not_12ms() { // ".12" carries decimal semantics: 0.12 s = 120 ms. - assert_eq!(slack_ts_to_iso8601("1714204397.12"), "2024-04-27T07:53:17.120Z"); + assert_eq!( + slack_ts_to_iso8601("1714204397.12"), + "2024-04-27T07:53:17.120Z" + ); } #[test] fn slack_ts_one_digit_fraction_is_100ms_not_1ms() { // ".1" carries decimal semantics: 0.1 s = 100 ms. - assert_eq!(slack_ts_to_iso8601("1714204397.1"), "2024-04-27T07:53:17.100Z"); + assert_eq!( + slack_ts_to_iso8601("1714204397.1"), + "2024-04-27T07:53:17.100Z" + ); } #[test]