166 lines
3.8 KiB
Go
166 lines
3.8 KiB
Go
package cardigann
|
|
|
|
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"
)
|
|
|
|
// ValidateURL validates that a URL is safe to make requests to.
|
|
// It blocks requests to private/internal IPs and non-HTTP schemes.
|
|
// Threat model T-10-05: SSRF protection.
|
|
func ValidateURL(rawURL string) error {
|
|
// Check for config override (testing only)
|
|
if os.Getenv("CARDIGANN_ALLOW_PRIVATE") == "true" {
|
|
return nil
|
|
}
|
|
|
|
// Basic scheme check before full URL parsing
|
|
lower := strings.ToLower(rawURL)
|
|
if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
|
|
return fmt.Errorf("URL scheme must be http or https, got: %q", rawURL)
|
|
}
|
|
|
|
// Extract hostname
|
|
host := rawURL
|
|
// Remove scheme
|
|
if idx := strings.Index(host, "://"); idx >= 0 {
|
|
host = host[idx+3:]
|
|
}
|
|
// Remove path and everything after
|
|
if idx := strings.Index(host, "/"); idx >= 0 {
|
|
host = host[:idx]
|
|
}
|
|
// Remove port
|
|
if idx := strings.LastIndex(host, ":"); idx >= 0 {
|
|
host = host[:idx]
|
|
}
|
|
// Remove user info
|
|
if idx := strings.LastIndex(host, "@"); idx >= 0 {
|
|
host = host[idx+1:]
|
|
}
|
|
|
|
host = strings.ToLower(strings.TrimSpace(host))
|
|
|
|
// Block well-known local hostnames
|
|
if host == "localhost" || strings.HasSuffix(host, ".local") || strings.HasSuffix(host, ".internal") {
|
|
return fmt.Errorf("hostname %q is blocked (private/local)", host)
|
|
}
|
|
|
|
// Resolve hostname and check IPs
|
|
resolveCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
resolver := net.Resolver{}
|
|
ips, err := resolver.LookupIPAddr(resolveCtx, host)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to resolve hostname %q: %w", host, err)
|
|
}
|
|
|
|
for _, ipAddr := range ips {
|
|
ip := ipAddr.IP
|
|
if isPrivateIP(ip) {
|
|
return fmt.Errorf("hostname %q resolves to private IP %s", host, ip)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isPrivateIP checks if an IP address is in a private/reserved range that
// outbound requests must not reach.
//
// IPv4 addresses are normalized with To4 before their octets are inspected.
// net.ParseIP and the resolver's IP-literal path return 16-byte slices even
// for IPv4. Indexing ip[0] directly would read the leading zero byte of the
// IPv4-mapped form instead of the first octet, and silently let 10.x,
// 172.16-31.x, 192.168.x and 169.254.x through. IPv4-mapped IPv6 addresses
// (::ffff:a.b.c.d) are therefore checked as IPv4.
//
// Blocked IPv4 ranges:
//   - 0.0.0.0/8      "this network", including 0.0.0.0
//   - 10.0.0.0/8     private
//   - 100.64.0.0/10  shared address space (CGNAT)
//   - 127.0.0.0/8    loopback
//   - 169.254.0.0/16 link-local
//   - 172.16.0.0/12  private
//   - 192.168.0.0/16 private
//
// Blocked IPv6 ranges: ::1 (loopback), :: (unspecified), fc00::/7 (unique
// local) and fe80::/10 (link-local).
//
// Anything that is neither a valid IPv4 address nor a 16-byte IPv6 address
// (including nil) is reported as private, so callers fail closed instead of
// panicking on an out-of-range index.
func isPrivateIP(ip net.IP) bool {
	if ip4 := ip.To4(); ip4 != nil {
		switch {
		case ip4[0] == 0: // 0.0.0.0/8
			return true
		case ip4[0] == 10: // 10.0.0.0/8
			return true
		case ip4[0] == 100 && ip4[1]&0xc0 == 64: // 100.64.0.0/10
			return true
		case ip4[0] == 127: // 127.0.0.0/8 (loopback)
			return true
		case ip4[0] == 169 && ip4[1] == 254: // 169.254.0.0/16 (link-local)
			return true
		case ip4[0] == 172 && ip4[1]&0xf0 == 16: // 172.16.0.0/12
			return true
		case ip4[0] == 192 && ip4[1] == 168: // 192.168.0.0/16
			return true
		}
		return false
	}

	if len(ip) != net.IPv6len {
		// Malformed address: fail closed rather than index out of range.
		return true
	}

	return ip.IsLoopback() || // ::1
		ip.IsUnspecified() || // ::
		ip[0]&0xfe == 0xfc || // fc00::/7 (unique local / private)
		(ip[0] == 0xfe && ip[1]&0xc0 == 0x80) // fe80::/10 (link-local)
}
|
|
|
|
// SafeHTTPClient returns an http.Client with timeouts and DNS checking.
|
|
func SafeHTTPClient() *http.Client {
|
|
return &http.Client{
|
|
Timeout: 15 * time.Second,
|
|
Transport: &http.Transport{
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
// Extract host from addr (may include port)
|
|
host, _, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
host = addr
|
|
}
|
|
|
|
// Resolve and check the IP
|
|
resolver := net.Resolver{}
|
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("DNS resolution failed for %q: %w", host, err)
|
|
}
|
|
|
|
for _, ipAddr := range ips {
|
|
if isPrivateIP(ipAddr.IP) {
|
|
return nil, fmt.Errorf("blocked private IP %s for host %q", ipAddr.IP, host)
|
|
}
|
|
}
|
|
|
|
// Use the first resolved IP
|
|
if len(ips) == 0 {
|
|
return nil, fmt.Errorf("no IP addresses found for %q", host)
|
|
}
|
|
|
|
dialer := net.Dialer{Timeout: 10 * time.Second}
|
|
return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), func() string {
|
|
_, port, _ := net.SplitHostPort(addr)
|
|
return port
|
|
}()))
|
|
},
|
|
},
|
|
}
|
|
}
|