Files
zeroclaw/src/providers/openai.rs
argenis de la rosa 976c5bbf3c hardening: fix 7 production weaknesses found in codebase scan
Scan findings and fixes:

1. Gateway buffer overflow (8KB → 64KB)
   - Fixed: Increased request buffer from 8,192 to 65,536 bytes
   - Large POST bodies (long prompts) were silently truncated

2. Gateway slow-loris attack (no read timeout → 30s)
   - Fixed: tokio::time::timeout(30s) on stream.read()
   - Malicious clients could hold connections indefinitely

3. Webhook secret timing attack (== → constant_time_eq)
   - Fixed: Now uses constant_time_eq() for secret comparison
   - Prevents timing side-channel on webhook authentication

4. Pairing brute force (no limit → 5 attempts + 5min lockout)
   - Fixed: PairingGuard tracks failed attempts with lockout
   - Returns 429 Too Many Requests with retry_after seconds

5. Shell tool hang (no timeout → 60s kill)
   - Fixed: tokio::time::timeout(60s) on Command::output()
   - Commands that hang are killed and return error

6. Shell tool OOM (unbounded output → 1MB cap)
   - Fixed: stdout/stderr truncated at 1MB with warning
   - Prevents memory exhaustion from verbose commands

7. Provider HTTP timeout (none → 120s request + 10s connect)
   - Fixed: All 5 providers (OpenRouter, Anthropic, OpenAI,
     Ollama, Compatible) now have reqwest timeouts
   - Ollama gets 300s (local models are slower)

949 tests passing, 0 clippy warnings, cargo fmt clean
2026-02-14 01:47:08 -05:00

223 lines
6.2 KiB
Rust

use crate::providers::traits::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Provider that talks to the OpenAI chat completions HTTP API.
pub struct OpenAiProvider {
    /// API key used for bearer authentication; `None` when no key was supplied,
    /// in which case chat calls fail before any request is made.
    api_key: Option<String>,
    /// HTTP client used for every request issued by this provider.
    client: Client,
}
/// JSON body posted to the chat completions endpoint.
#[derive(Serialize, Debug)]
struct ChatRequest {
    /// Model identifier, forwarded unchanged from the caller.
    model: String,
    /// Messages in the order they are sent (optional system message first,
    /// then the user message).
    messages: Vec<Message>,
    /// Temperature value, forwarded unchanged from the caller.
    temperature: f64,
}
/// A single message entry inside a [`ChatRequest`].
#[derive(Serialize, Debug)]
struct Message {
    /// Author role of the message; this file emits `"system"` and `"user"`.
    role: String,
    /// Text of the message.
    content: String,
}
/// Subset of the chat completions response body that this provider reads.
#[derive(Deserialize, Debug)]
struct ChatResponse {
    /// Returned choices; only the first one is used when producing a reply.
    choices: Vec<Choice>,
}
/// One entry of the `choices` array in a [`ChatResponse`].
#[derive(Deserialize, Debug)]
struct Choice {
    /// The message carried by this choice.
    message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
impl OpenAiProvider {
    /// Creates a provider holding an owned copy of `api_key` (if any) and an
    /// HTTP client with a 120-second request timeout and a 10-second connect
    /// timeout.
    ///
    /// If the configured client cannot be built, a default client from
    /// `Client::new()` is used instead.
    pub fn new(api_key: Option<&str>) -> Self {
        use std::time::Duration;

        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .connect_timeout(Duration::from_secs(10))
            .build()
            .unwrap_or_else(|_| Client::new());

        Self {
            api_key: api_key.map(str::to_owned),
            client,
        }
    }
}
#[async_trait]
impl Provider for OpenAiProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
temperature,
};
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("OpenAI API error: {error}");
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request message with the given role and content.
    fn msg(role: &str, content: &str) -> Message {
        Message {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    /// Parses a chat completions payload, panicking if it is not valid.
    fn parse(payload: &str) -> ChatResponse {
        serde_json::from_str(payload).unwrap()
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn creates_with_key() {
        let provider = OpenAiProvider::new(Some("sk-proj-abc123"));
        assert_eq!(provider.api_key.as_deref(), Some("sk-proj-abc123"));
    }

    #[test]
    fn creates_without_key() {
        let provider = OpenAiProvider::new(None);
        assert!(provider.api_key.is_none());
    }

    #[test]
    fn creates_with_empty_key() {
        let provider = OpenAiProvider::new(Some(""));
        assert_eq!(provider.api_key.as_deref(), Some(""));
    }

    // --- missing key is rejected before any request is sent ----------------

    #[tokio::test]
    async fn chat_fails_without_key() {
        let provider = OpenAiProvider::new(None);
        let outcome = provider
            .chat_with_system(None, "hello", "gpt-4o", 0.7)
            .await;
        assert!(outcome.is_err());
        let text = outcome.unwrap_err().to_string();
        assert!(text.contains("API key not set"));
    }

    #[tokio::test]
    async fn chat_with_system_fails_without_key() {
        let provider = OpenAiProvider::new(None);
        let outcome = provider
            .chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5)
            .await;
        assert!(outcome.is_err());
    }

    // --- request serialization ----------------------------------------------

    #[test]
    fn request_serializes_with_system_message() {
        let request = ChatRequest {
            model: "gpt-4o".to_string(),
            messages: vec![msg("system", "You are ZeroClaw"), msg("user", "hello")],
            temperature: 0.7,
        };
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"role\":\"system\""));
        assert!(encoded.contains("\"role\":\"user\""));
        assert!(encoded.contains("gpt-4o"));
    }

    #[test]
    fn request_serializes_without_system() {
        let request = ChatRequest {
            model: "gpt-4o".to_string(),
            messages: vec![msg("user", "hello")],
            temperature: 0.0,
        };
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(!encoded.contains("system"));
        assert!(encoded.contains("\"temperature\":0.0"));
    }

    // --- response deserialization -------------------------------------------

    #[test]
    fn response_deserializes_single_choice() {
        let parsed = parse(r#"{"choices":[{"message":{"content":"Hi!"}}]}"#);
        assert_eq!(parsed.choices.len(), 1);
        assert_eq!(parsed.choices[0].message.content, "Hi!");
    }

    #[test]
    fn response_deserializes_empty_choices() {
        let parsed = parse(r#"{"choices":[]}"#);
        assert!(parsed.choices.is_empty());
    }

    #[test]
    fn response_deserializes_multiple_choices() {
        let parsed =
            parse(r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#);
        assert_eq!(parsed.choices.len(), 2);
        assert_eq!(parsed.choices[0].message.content, "A");
    }

    #[test]
    fn response_with_unicode() {
        let parsed = parse(r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#);
        assert_eq!(parsed.choices[0].message.content, "こんにちは 🦀");
    }

    #[test]
    fn response_with_long_content() {
        let filler = "x".repeat(100_000);
        let payload = format!(r#"{{"choices":[{{"message":{{"content":"{filler}"}}}}]}}"#);
        let parsed = parse(&payload);
        assert_eq!(parsed.choices[0].message.content.len(), 100_000);
    }
}