Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update dynamic rules
  • Loading branch information
atasmohammadi committed Jun 12, 2025
commit b863e0ebfc11085e76cc53c41e2a7c10ca6f0a33
71 changes: 41 additions & 30 deletions internal/firewall/enable.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,6 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
return nil
}

func (c *Config) disable(ctx context.Context) (err error) {
if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}

const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}

return nil
}

// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}

func (c *Config) enable(ctx context.Context) (err error) {
touched := false
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
Expand All @@ -90,6 +60,11 @@ func (c *Config) enable(ctx context.Context) (err error) {
}
}()

// Clear any previously applied post-rules
if err = c.clearAppliedPostRules(ctx); err != nil {
c.logger.Warn("failed to clear previous post-rules: " + err.Error())
}

Comment on lines +63 to 67
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not needed, given the firewall is only enabled once on a clear set of rules. Plus disable does remove all rules anyway, in case the firewall would get disabled then re-enabled.

// Loopback traffic
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
return err
Expand Down Expand Up @@ -144,13 +119,49 @@ func (c *Config) enable(ctx context.Context) (err error) {
return fmt.Errorf("redirecting ports: %w", err)
}

// Apply post-rules only once at the end
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}

return nil
}

func (c *Config) disable(ctx context.Context) (err error) {
// Clear applied post-rules when disabling
if err = c.clearAppliedPostRules(ctx); err != nil {
c.logger.Warn("failed to clear post-rules during disable: " + err.Error())
}
Comment on lines +131 to +134
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed because of existing clearAllRules


if err = c.clearAllRules(ctx); err != nil {
return fmt.Errorf("clearing all rules: %w", err)
}
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv4 policies: %w", err)
}
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}

const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}

return nil
}

// To use in defered call when enabling the firewall.
func (c *Config) fallbackToDisabled(ctx context.Context) {
if ctx.Err() != nil {
return
}
if err := c.disable(ctx); err != nil {
c.logger.Error("failed reversing firewall changes: " + err.Error())
}
}

Comment on lines +130 to +164
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move these blocks where they previously were to see diffs more clearly.

func (c *Config) allowVPNIP(ctx context.Context) (err error) {
if !c.vpnConnection.IP.IsValid() {
return nil
Expand Down
35 changes: 29 additions & 6 deletions internal/firewall/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package firewall
import (
"context"
"net/netip"
"strings"
"sync"

"github.com/qdm12/gluetun/internal/models"
Expand All @@ -29,15 +30,10 @@ type Config struct { //nolint:maligned
outboundSubnets []netip.Prefix
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
portRedirections portRedirections
appliedPostRules []string // Track applied post-rules to avoid duplicates
stateMutex sync.Mutex
}

// applyUserPostRules applies user-defined post firewall rules
func (c *Config) applyUserPostRules(ctx context.Context) error {
const remove = false
return c.runUserPostRules(ctx, c.customRulesPath, remove)
}

// NewConfig creates a new Config instance and returns an error
// if no iptables implementation is available.
func NewConfig(ctx context.Context, logger Logger,
Expand Down Expand Up @@ -66,3 +62,30 @@ func NewConfig(ctx context.Context, logger Logger,
localNetworks: localNetworks,
}, nil
}

// clearAppliedPostRules removes all previously applied post-rules
func (c *Config) clearAppliedPostRules(ctx context.Context) error {
for _, rule := range c.appliedPostRules {
flippedRule := flipRule(rule)
if strings.Contains(rule, "ip6tables") {
if err := c.runIP6tablesInstruction(ctx, flippedRule); err != nil {
c.logger.Debug("failed to remove post-rule (may not exist): " + err.Error())
}
} else {
if err := c.runIptablesInstruction(ctx, flippedRule); err != nil {
c.logger.Debug("failed to remove post-rule (may not exist): " + err.Error())
}
}
}
c.appliedPostRules = nil
return nil
}

// applyPostRulesOnce applies post-rules only if they haven't been applied yet
func (c *Config) applyPostRulesOnce(ctx context.Context) error {
if len(c.appliedPostRules) > 0 {
c.logger.Debug("post-rules already applied, skipping")
return nil
}
return c.runUserPostRules(ctx, c.customRulesPath, false)
}
30 changes: 21 additions & 9 deletions internal/firewall/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,6 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
} else if err != nil {
return err
}

// Log when post-rules are being applied
if !remove {
c.logger.Info("applying user-defined post firewall rules from " + filepath)
} else {
c.logger.Info("removing user-defined post firewall rules from " + filepath)
}


b, err := io.ReadAll(file)
if err != nil {
_ = file.Close()
Expand All @@ -292,6 +283,7 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
_ = c.runIptablesInstruction(ctx, flipRule(rule))
}
}()

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit unneeded newline change here

Suggested change

for _, line := range lines {
var ipv4 bool
var rule string
Expand Down Expand Up @@ -322,6 +314,21 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
rule = flipRule(rule)
}

// Check if this rule was already applied (avoid duplicates)
if !remove {
Comment on lines +317 to +318
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we track duplicates? That should be up to the user to ensure they don't have duplicate iptables commands

ruleAlreadyApplied := false
for _, appliedRule := range c.appliedPostRules {
if appliedRule == line {
ruleAlreadyApplied = true
break
}
}
if ruleAlreadyApplied {
c.logger.Debug("skipping duplicate post-rule: " + line)
continue
}
}

switch {
case ipv4:
err = c.runIptablesInstruction(ctx, rule)
Expand All @@ -335,6 +342,11 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
}

successfulRules = append(successfulRules, rule)

// Track applied rules (only when adding, not removing)
if !remove {
c.appliedPostRules = append(c.appliedPostRules, line)
}
}
return nil
}
71 changes: 41 additions & 30 deletions internal/firewall/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,47 @@ import (
)

func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

interfaceSet, ok := c.allowedInputPorts[port]
if !ok {
interfaceSet = make(map[string]struct{})
c.allowedInputPorts[port] = interfaceSet
}

_, alreadySet := interfaceSet[intf]
if alreadySet {
return nil
}

if c.enabled {
const remove = false
err = c.acceptInputToPort(ctx, intf, port, remove)
if err != nil {
return fmt.Errorf("accepting input port %d on interface %s: %w",
port, intf, err)
}

// ADD THIS: Re-apply user post-rules after port changes
if err = c.applyUserPostRules(ctx); err != nil {
return fmt.Errorf("re-applying user post-rules after port change: %w", err)
}
}

interfaceSet[intf] = struct{}{}
return nil
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

if port == 0 {
return nil
}

if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal state")
existingInterfaces, ok := c.allowedInputPorts[port]
if !ok {
existingInterfaces = make(map[string]struct{})
}
existingInterfaces[intf] = struct{}{}
c.allowedInputPorts[port] = existingInterfaces
return nil
}

netInterfaces, has := c.allowedInputPorts[port]
if !has {
netInterfaces = make(map[string]struct{})
} else if _, exists := netInterfaces[intf]; exists {
return nil
}

c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")

const remove = false
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("allowing input to port %d through interface %s: %w",
port, intf, err)
}
netInterfaces[intf] = struct{}{}
c.allowedInputPorts[port] = netInterfaces

// Apply post-rules only once after adding the port, and only if not already applied
if err := c.applyPostRulesOnce(ctx); err != nil {
c.logger.Warn("failed to apply post-rules after adding port: " + err.Error())
}

Comment on lines +45 to +49
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this different from just running the post rules once at firewall enabling? Especially given applyPostRulesOnce is not even aware of the newly allowed port.

return nil
}

func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {
Expand Down