Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: add remote agent auth tunneling via SSH agent protocol extensions
Implement proxying of auth requests back through SSH agent forwarding,
enabling users to SSH to a remote host and run `epithet agent fish` to
get certificates from an internal CA while auth plugins (FIDO2, browser)
run on the local machine.

Uses SSH agent protocol extensions (message type 27) with a registry
pattern for extensibility:
- epithet-hello@epithet.dev — probe if socket is an epithet agent
- epithet-auth@epithet.dev — request auth token from upstream

Key components:
- ProxyAgent: generic ExtendedAgent wrapper with extension handler registry
- ProxyListener: per-connection upstream dialing for agent protocol proxying
- Upstream client: ProbeUpstream() and RequestAuth() functions
- Broker: WithUpstream option and Authenticate() helper (try upstream, fall back to local)
- CLI: `epithet agent <shell>` wrapper mode mirroring ssh-agent's pattern

Multi-hop chaining works naturally — each layer proxies to its upstream.
No changes to CA or policy server required.
  • Loading branch information
brianm committed Apr 18, 2026
commit fded82ae5b1cbb14736274add19434e4bcb116af
301 changes: 226 additions & 75 deletions cmd/epithet/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"

"github.com/epithet-ssh/epithet/pkg/agent"
"github.com/epithet-ssh/epithet/pkg/broker"
"github.com/epithet-ssh/epithet/pkg/caclient"
"github.com/epithet-ssh/epithet/pkg/tlsconfig"
Expand All @@ -32,74 +35,243 @@ type AgentCLI struct {
}

// AgentStartCLI is the default subcommand that starts the agent/broker.
//
// When Shell is provided, it runs in wrapper mode: wrapping a shell with an
// epithet-aware agent proxy (mirroring ssh-agent's "ssh-agent bash" pattern).
// When Shell is empty, the broker runs as a long-lived daemon.
type AgentStartCLI struct {
	// Shell is the optional positional argument naming the shell to launch.
	// A non-empty value selects wrapper mode.
	Shell string `arg:"" optional:"" help:"Shell to wrap with epithet agent (enables wrapper mode)"`
}

// Run dispatches to one of the two start modes based on the Shell argument:
// daemon mode when no shell was given, wrapper mode otherwise. The error from
// the selected mode is returned unchanged.
func (s *AgentStartCLI) Run(parent *AgentCLI, logger *slog.Logger, tlsCfg tlsconfig.Config) error {
	wrapperMode := s.Shell != ""
	if !wrapperMode {
		return s.runDaemon(parent, logger, tlsCfg)
	}
	return s.runWrapper(parent, logger, tlsCfg)
}

// runDaemon runs the broker as a long-lived daemon (original behavior).
//
// It builds the broker via setupBroker, writes the SSH config snippet into the
// run directory, and then blocks in b.Serve until the broker stops or a
// SIGINT/SIGTERM cancels the context. The run directory is removed on return.
//
// Returns the setup error unchanged, a wrapped "broker serve error" if Serve
// fails for a reason other than context cancellation, or nil on clean shutdown.
func (s *AgentStartCLI) runDaemon(parent *AgentCLI, logger *slog.Logger, tlsCfg tlsconfig.Config) error {
	b, tempDir, brokerSock, agentDir, homeDir, err := setupBroker(parent, logger, tlsCfg)
	if err != nil {
		return err
	}

	// Clean up temp directory on exit.
	defer func() {
		if err := os.RemoveAll(tempDir); err != nil {
			logger.Warn("failed to remove temp directory", "error", err, "path", tempDir)
		} else {
			logger.Debug("removed temp directory", "path", tempDir)
		}
	}()

	// Set up context with cancellation on signals.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	// Unregister the channel on return so the signal package stops delivering
	// to a channel nobody reads any more.
	defer signal.Stop(sigChan)
	go func() {
		// Also watch ctx so this goroutine ends when Serve returns on its own
		// instead of waiting forever for a signal that never arrives.
		select {
		case <-sigChan:
			logger.Info("received shutdown signal")
			cancel()
		case <-ctx.Done():
		}
	}()

	// Generate SSH config. Failure is not fatal: the broker is still usable,
	// the user just has to wire up ssh themselves.
	sshConfigPath := filepath.Join(tempDir, "ssh-config.conf")
	if err := parent.generateSSHConfig(sshConfigPath, agentDir, brokerSock, homeDir); err != nil {
		logger.Warn("failed to generate SSH config", "error", err, "path", sshConfigPath)
	} else {
		includePattern := filepath.Join(homeDir, ".epithet", "run", "*", "ssh-config.conf")
		if err := checkSSHConfigInclude(homeDir, includePattern, logger); err != nil {
			logger.Warn(fmt.Sprintf("Add 'Include %s' to ~/.ssh/config", includePattern))
		}
		logger.Debug("generated SSH config", "path", sshConfigPath)
	}

	// Start broker.
	logger.Info("starting broker", "socket", brokerSock)
	err = b.Serve(ctx)
	// errors.Is (rather than ==) so a wrapped cancellation error is still
	// treated as a clean shutdown.
	if err != nil && !errors.Is(err, context.Canceled) {
		return fmt.Errorf("broker serve error: %w", err)
	}

	logger.Info("broker shutdown complete")
	return nil
}

// runWrapper wraps a shell with an epithet-aware agent proxy. It probes the
// current SSH_AUTH_SOCK for an upstream epithet agent, starts the broker, and
// creates a proxy listener that handles epithet protocol extensions while
// delegating standard SSH agent operations to the upstream agent.
//
// The shell named by s.Shell is started as a child process with SSH_AUTH_SOCK
// pointing at the proxy socket (when an upstream socket exists), and
// SIGINT/SIGTERM received by this process are forwarded to it.
//
// Returns a setup or broker start-up error, a wrapped "shell failed" error if
// the shell could not be started or waited on, or nil when the shell exits
// with status 0. When the shell exits with a non-zero status this function
// does not return: it cleans up and terminates the process with the shell's
// exit code via os.Exit.
func (s *AgentStartCLI) runWrapper(parent *AgentCLI, logger *slog.Logger, tlsCfg tlsconfig.Config) error {
	upstreamSocket := os.Getenv("SSH_AUTH_SOCK")

	// Probe for upstream epithet agent.
	var depth int
	var brokerOpts []broker.Option
	if upstreamSocket != "" {
		hello, err := agent.ProbeUpstream(upstreamSocket)
		if err != nil {
			logger.Warn("failed to probe upstream agent", "error", err)
		} else if hello != nil {
			depth = hello.ChainDepth + 1
			brokerOpts = append(brokerOpts, broker.WithUpstream(upstreamSocket))
			logger.Info("detected upstream epithet agent", "chain_depth", depth)
		} else {
			logger.Debug("upstream is vanilla ssh-agent")
		}
	} else {
		logger.Debug("no SSH_AUTH_SOCK set")
	}

	b, tempDir, brokerSock, agentDir, homeDir, err := setupBrokerWithOptions(parent, logger, tlsCfg, brokerOpts...)
	if err != nil {
		return err
	}

	// Clean up temp directory on exit. This is a named closure because the
	// os.Exit path below skips deferred calls and must invoke it explicitly.
	cleanup := func() {
		if err := os.RemoveAll(tempDir); err != nil {
			logger.Warn("failed to remove temp directory", "error", err, "path", tempDir)
		}
	}
	defer cleanup()

	// Set up context with cancellation.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Start broker in background. brokerErr is buffered so the goroutine can
	// always deliver its result and exit, even if nobody receives it.
	brokerErr := make(chan error, 1)
	go func() {
		brokerErr <- b.Serve(ctx)
	}()
	// Wait for readiness, but also watch for Serve returning first: otherwise
	// a start-up failure would leave us blocked on Ready() forever.
	select {
	case <-b.Ready():
	case err := <-brokerErr:
		if err == nil {
			err = errors.New("broker exited before becoming ready")
		}
		return fmt.Errorf("broker serve error: %w", err)
	}

	// Generate SSH config. Failure is not fatal; only a warning is logged.
	sshConfigPath := filepath.Join(tempDir, "ssh-config.conf")
	if err := parent.generateSSHConfig(sshConfigPath, agentDir, brokerSock, homeDir); err != nil {
		logger.Warn("failed to generate SSH config", "error", err, "path", sshConfigPath)
	} else {
		includePattern := filepath.Join(homeDir, ".epithet", "run", "*", "ssh-config.conf")
		if err := checkSSHConfigInclude(homeDir, includePattern, logger); err != nil {
			logger.Warn(fmt.Sprintf("Add 'Include %s' to ~/.ssh/config", includePattern))
		}
	}

	// Create proxy listener if we have an upstream socket to proxy.
	if upstreamSocket != "" {
		proxySock := filepath.Join(tempDir, "proxy.sock")
		setup := func(p *agent.ProxyAgent) {
			p.RegisterExtension(agent.ExtensionHello, agent.HelloHandler(depth))
			p.RegisterExtension(agent.ExtensionAuth, agent.AuthHandler(func() (string, error) {
				// NOTE(review): the meaning of the nil argument is not
				// visible from this file; confirm against Broker.Authenticate.
				return b.Authenticate(nil)
			}))
		}
		proxyListener := agent.NewProxyListener(logger, proxySock, upstreamSocket, setup)
		go func() {
			if err := proxyListener.Serve(ctx); err != nil && !errors.Is(err, context.Canceled) {
				logger.Error("proxy listener error", "error", err)
			}
		}()
		// There is currently no wait for the listener to come up, so the
		// shell may start slightly before the proxy socket accepts connections.
		// TODO: add Ready() to ProxyListener like Broker has.
		upstreamSocket = proxySock
		logger.Info("proxy agent listening", "socket", proxySock)
	}

	// Build child environment with updated SSH_AUTH_SOCK.
	childEnv := os.Environ()
	if upstreamSocket != "" {
		childEnv = replaceEnv(childEnv, "SSH_AUTH_SOCK", upstreamSocket)
	}

	// Run the shell as a child process.
	cmd := exec.Command(s.Shell)
	cmd.Env = childEnv
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	// Register for signals before the child starts so none are missed; the
	// buffered channel holds one until the forwarder below is running.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	logger.Info("starting shell", "shell", s.Shell)
	if err := cmd.Start(); err != nil {
		signal.Stop(sigChan)
		return fmt.Errorf("shell failed: %w", err)
	}

	// Forward signals to child. The forwarder starts only after Start has
	// returned, so cmd.Process is already set and is never read while being
	// written. Closing forwardDone stops the goroutine once the child is gone.
	forwardDone := make(chan struct{})
	go func() {
		for {
			select {
			case sig := <-sigChan:
				if err := cmd.Process.Signal(sig); err != nil {
					logger.Debug("failed to forward signal to shell", "signal", sig, "error", err)
				}
			case <-forwardDone:
				return
			}
		}
	}()

	waitErr := cmd.Wait()
	signal.Stop(sigChan)
	close(forwardDone)
	cancel()

	if waitErr != nil {
		var exitErr *exec.ExitError
		if errors.As(waitErr, &exitErr) {
			// os.Exit does not run deferred functions, so remove the run
			// directory by hand before propagating the shell's exit code.
			cleanup()
			os.Exit(exitErr.ExitCode())
		}
		return fmt.Errorf("shell failed: %w", waitErr)
	}
	return nil
}

// setupBroker is the option-less form of setupBrokerWithOptions: it forwards
// its arguments with an empty option list and returns that call's results
// unchanged.
func setupBroker(parent *AgentCLI, logger *slog.Logger, tlsCfg tlsconfig.Config) (*broker.Broker, string, string, string, string, error) {
	var noOptions []broker.Option
	return setupBrokerWithOptions(parent, logger, tlsCfg, noOptions...)
}

// setupBrokerWithOptions creates the broker, temp directories, and resolves auth.
// Returns (broker, tempDir, brokerSock, agentDir, homeDir, error).
func setupBrokerWithOptions(parent *AgentCLI, logger *slog.Logger, tlsCfg tlsconfig.Config, opts ...broker.Option) (*broker.Broker, string, string, string, string, error) {
if len(parent.CaURL) == 0 {
return fmt.Errorf("--ca-url is required (at least one)")
return nil, "", "", "", "", fmt.Errorf("--ca-url is required (at least one)")
}

// Parse CA URLs into endpoints
caEndpoints, err := caclient.ParseCAURLs(parent.CaURL)
if err != nil {
return fmt.Errorf("invalid CA URL: %w", err)
return nil, "", "", "", "", fmt.Errorf("invalid CA URL: %w", err)
}

logger.Debug("agent start command received", "ca_urls", parent.CaURL, "ca_timeout", parent.CaTimeout, "ca_cooldown", parent.CaCooldown)
logger.Debug("agent command received", "ca_urls", parent.CaURL, "ca_timeout", parent.CaTimeout, "ca_cooldown", parent.CaCooldown)

// Validate all CA URLs require TLS (unless --insecure)
for _, ep := range caEndpoints {
if err := tlsCfg.ValidateURL(ep.URL); err != nil {
return fmt.Errorf("CA URL %q: %w", ep.URL, err)
return nil, "", "", "", "", fmt.Errorf("CA URL %q: %w", ep.URL, err)
}
}

// Get home directory
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get home directory: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to get home directory: %w", err)
}

// Create a unique temporary directory for this broker instance
// Use a hash of the CA URLs to make it deterministic
instanceID := hashString(fmt.Sprintf("%v", parent.CaURL))
runDir := filepath.Join(homeDir, ".epithet", "run")
tempDir := filepath.Join(runDir, instanceID)

// Clean up stale run directories from dead processes
cleanupStaleRunDirs(runDir, logger)

// Clean up temp directory on exit
defer func() {
if err := os.RemoveAll(tempDir); err != nil {
logger.Warn("failed to remove temp directory", "error", err, "path", tempDir)
} else {
logger.Debug("removed temp directory", "path", tempDir)
}
}()

// Create temp directory
if err := os.MkdirAll(tempDir, 0700); err != nil {
return fmt.Errorf("failed to create temp directory: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to create temp directory: %w", err)
}

// Write PID file for stale detection
pidFile := filepath.Join(tempDir, "broker.pid")
if err := os.WriteFile(pidFile, []byte(strconv.Itoa(os.Getpid())), 0600); err != nil {
return fmt.Errorf("failed to write PID file: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to write PID file: %w", err)
}

// Define paths within temp directory
brokerSock := filepath.Join(tempDir, "broker.sock")
agentDir := filepath.Join(tempDir, "agent")

// Create agent directory
if err := os.MkdirAll(agentDir, 0700); err != nil {
return fmt.Errorf("failed to create agent directory: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to create agent directory: %w", err)
}

// Create CA client
caClientOpts := []caclient.Option{
caclient.WithLogger(logger),
caclient.WithTimeout(parent.CaTimeout),
Expand All @@ -108,81 +280,60 @@ func (s *AgentStartCLI) Run(parent *AgentCLI, logger *slog.Logger, tlsCfg tlscon
}
caClient, err := caclient.New(caEndpoints, caClientOpts...)
if err != nil {
return fmt.Errorf("failed to create CA client: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to create CA client: %w", err)
}

// Resolve auth command: use --auth if provided, otherwise discover from CA.
authCommand := parent.Auth
if authCommand == "" {
logger.Debug("no --auth provided, discovering from CA")

// Fetch public key (also caches discovery URL).
_, err := caClient.GetPublicKey(context.Background())
if err != nil {
return fmt.Errorf("failed to get CA public key: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to get CA public key: %w", err)
}

// Fetch unauthenticated discovery (auth config only, no token needed).
discovery, err := caClient.GetDiscovery(context.Background(), "")
if err != nil {
return fmt.Errorf("failed to get discovery config (try --auth to specify manually): %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to get discovery config (try --auth to specify manually): %w", err)
}
if discovery == nil || discovery.Auth == nil {
return fmt.Errorf("CA did not return auth config in discovery (try --auth to specify manually)")
return nil, "", "", "", "", fmt.Errorf("CA did not return auth config in discovery (try --auth to specify manually)")
}

// Convert auth config to command.
authCommand, err = broker.AuthConfigToCommand(*discovery.Auth)
if err != nil {
return fmt.Errorf("failed to convert discovery auth config: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to convert discovery auth config: %w", err)
}

logger.Info("discovered auth config from CA", "type", discovery.Auth.Type, "issuer", discovery.Auth.Issuer)
}

// Create broker
b, err := broker.New(*logger, brokerSock, authCommand, caClient, agentDir)
b, err := broker.New(*logger, brokerSock, authCommand, caClient, agentDir, opts...)
if err != nil {
return fmt.Errorf("failed to create broker: %w", err)
return nil, "", "", "", "", fmt.Errorf("failed to create broker: %w", err)
}

// Set up context with cancellation on signals
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigChan
logger.Info("received shutdown signal")
cancel()
}()

// Generate SSH config file in the temp directory
sshConfigPath := filepath.Join(tempDir, "ssh-config.conf")

if err := parent.generateSSHConfig(sshConfigPath, agentDir, brokerSock, homeDir); err != nil {
logger.Warn("failed to generate SSH config", "error", err, "path", sshConfigPath)
// Don't fail startup, just warn
} else {
// Check if ~/.ssh/config has the Include directive
includePattern := filepath.Join(homeDir, ".epithet", "run", "*", "ssh-config.conf")
if err := checkSSHConfigInclude(homeDir, includePattern, logger); err != nil {
return b, tempDir, brokerSock, agentDir, homeDir, nil
}

logger.Warn(fmt.Sprintf("Add 'Include %s' to ~/.ssh/config", includePattern))
// replaceEnv returns a copy of environ with key set to value.
// If key already exists, it is replaced; otherwise it is appended.
//
// environ is a list of "KEY=value" entries and is never modified. Matching is
// done on the exact "key=" prefix, so a variable whose name merely starts with
// key is left untouched. Every entry for key is rewritten if it occurs more
// than once.
func replaceEnv(environ []string, key, value string) []string {
	prefix := key + "="
	entry := prefix + value
	// +1 leaves room for the appended entry when key is not present.
	result := make([]string, 0, len(environ)+1)
	found := false
	for _, e := range environ {
		if strings.HasPrefix(e, prefix) {
			result = append(result, entry)
			found = true
		} else {
			result = append(result, e)
		}
	}
	if !found {
		result = append(result, entry)
	}
	return result
}

// generateSSHConfig writes an SSH config file for epithet
Expand Down
Loading
Loading