From 2a38a1f462a602cfec1ace6d2477934bcf2a9522 Mon Sep 17 00:00:00 2001 From: kyriediculous Date: Fri, 20 Sep 2024 15:56:41 +0800 Subject: [PATCH] api key authentication --- db/supabase.go | 34 ++++++++++++++++++++++++++++++++++ go.mod | 6 ++++++ go.sum | 12 ++++++++++++ middleware/auth.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ openai-api.go | 14 +++++++++++++- server/http.go | 14 +++++++++++--- 6 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 db/supabase.go create mode 100644 middleware/auth.go diff --git a/db/supabase.go b/db/supabase.go new file mode 100644 index 0000000..b50d53e --- /dev/null +++ b/db/supabase.go @@ -0,0 +1,34 @@ +// database/supabase.go + +package db + +import ( + "fmt" + + "github.com/supabase-community/supabase-go" +) + +type SupabaseAPIKeyStore struct { + client *supabase.Client +} + +func NewSupabaseAPIKeyStore(supabaseURL, supabaseKey string) (*SupabaseAPIKeyStore, error) { + client, err := supabase.NewClient(supabaseURL, supabaseKey, nil) + if err != nil { + return nil, fmt.Errorf("cannot initialize Supabase client: %w", err) + } + return &SupabaseAPIKeyStore{client: client}, nil +} + +func (s *SupabaseAPIKeyStore) ValidateAPIKey(apiKey string) (bool, error) { + _, count, err := s.client.From("api_keys"). + Select("id", "exact", false). + Eq("key", apiKey). + Execute() + + if err != nil { + return false, fmt.Errorf("error checking API key: %w", err) + } + + return count > 0, nil +} diff --git a/go.mod b/go.mod index 3a5a562..5a96da3 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,12 @@ require ( github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect + github.com/supabase-community/functions-go v0.0.0-20220927045802-22373e6cb51d // indirect + github.com/supabase-community/gotrue-go v1.2.0 // indirect + github.com/supabase-community/postgrest-go v0.0.11 // indirect + github.com/supabase-community/storage-go v0.7.0 // indirect + github.com/supabase-community/supabase-go v0.0.4 // indirect + github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/vincent-petithory/dataurl v1.0.0 // indirect diff --git a/go.sum b/go.sum index e370b71..19aaded 100644 --- a/go.sum +++ b/go.sum @@ -134,6 +134,18 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/supabase-community/functions-go v0.0.0-20220927045802-22373e6cb51d h1:LOrsumaZy615ai37h9RjUIygpSubX+F+6rDct1LIag0= +github.com/supabase-community/functions-go v0.0.0-20220927045802-22373e6cb51d/go.mod h1:nnIju6x3+OZSojtGQCQzu0h3kv4HdIZk+UWCnNxtSak= +github.com/supabase-community/gotrue-go v1.2.0 h1:Zm7T5q3qbuwPgC6xyomOBKrSb7X5dvmjDZEmNST7MoE= +github.com/supabase-community/gotrue-go v1.2.0/go.mod h1:86DXBiAUNcbCfgbeOPEh0PQxScLfowUbYgakETSFQOw= +github.com/supabase-community/postgrest-go v0.0.11 h1:717GTUMfLJxSBuAeEQG2MuW5Q62Id+YrDjvjprTSErg= +github.com/supabase-community/postgrest-go v0.0.11/go.mod h1:cw6LfzMyK42AOSBA1bQ/HZ381trIJyuui2GWhraW7Cc= +github.com/supabase-community/storage-go v0.7.0 h1:cJ8HLbbnL54H5rHPtHfiwtpRwcbDfA3in9HL/ucHnqA= +github.com/supabase-community/storage-go v0.7.0/go.mod h1:oBKcJf5rcUXy3Uj9eS5wR6mvpwbmvkjOtAA+4tGcdvQ= +github.com/supabase-community/supabase-go v0.0.4 h1:sxMenbq6N8a3z9ihNpN3lC2FL3E1YuTQsjX09VPRp+U= +github.com/supabase-community/supabase-go v0.0.4/go.mod h1:SSHsXoOlc+sq8XeXaf0D3gE2pwrq5bcUfzm0+08u/o8= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..9dd313d --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,45 @@ +// middleware/auth.go + +package middleware + +import ( + "github.com/gin-gonic/gin" +) + +type APIKeyStore interface { + ValidateAPIKey(apiKey string) (bool, error) +} + +type Auth struct { + apiKeyStore APIKeyStore +} + +func NewAuthMiddleware(apiKeyStore APIKeyStore) *Auth { + return &Auth{apiKeyStore: apiKeyStore} +} + +func (a *Auth) AuthRequired() gin.HandlerFunc { + return func(c *gin.Context) { + apiKey := c.GetHeader("X-API-Key") + if apiKey == "" { + c.JSON(401, gin.H{"error": "API key is required"}) + c.Abort() + return + } + + valid, err := a.apiKeyStore.ValidateAPIKey(apiKey) + if err != nil { + c.JSON(500, gin.H{"error": "Error checking API key"}) + c.Abort() + return + } + + if !valid { + c.JSON(401, gin.H{"error": "Invalid API key"}) + c.Abort() + return + } + + c.Next() + } +} diff --git a/openai-api.go b/openai-api.go index c0a9b3b..649e74b 100644 --- a/openai-api.go +++ b/openai-api.go @@ -4,6 +4,7 @@ import ( "flag" "log" + "github.com/livepool-io/openai-middleware/db" "github.com/livepool-io/openai-middleware/middleware" "github.com/livepool-io/openai-middleware/server" ) @@ -11,9 +12,20 @@ import ( func main() { gatewayURL := flag.String("gateway", "http://your-api-host", "The URL of the gateway API") port := flag.String("port", "8080", "The port to run the server on") + dbURL := flag.String("db-url", "http://your-db-host", "The URL of the database") + dbKey := flag.String("db-key", "your-db-key", "The key to access the database") + flag.Parse() + + apiKeyStore, err := db.NewSupabaseAPIKeyStore(*dbURL, *dbKey) + if err != nil { + log.Fatalf("Failed to create Supabase API key store: %v", err) + } + gateway := middleware.NewGateway(*gatewayURL) - server, err := server.NewServer(gateway) + auth := middleware.NewAuthMiddleware(apiKeyStore) + + server, err := server.NewServer(auth, gateway) if err != nil { log.Fatalf("Failed to create server: %v", err) } diff --git a/server/http.go b/server/http.go index 29afc2c..d573ab7 100644 --- a/server/http.go +++ b/server/http.go @@ -14,22 +14,30 @@ import ( type Server struct { // We can add fields here later if needed, such as a database connection + auth *middleware.Auth gateway *middleware.Gateway } -func NewServer(gateway *middleware.Gateway) (*Server, error) { +func NewServer(auth *middleware.Auth, gateway *middleware.Gateway) (*Server, error) { if gateway == nil { return nil, fmt.Errorf("gateway is required") } + return &Server{ gateway: gateway, + auth: auth, }, nil } func (s *Server) Start(port string) error { r := gin.Default() - r.POST("/v1/chat/completions", s.handleChatCompletion) - r.POST("/v1/completions", s.handleCompletion) + // Apply auth middleware to protected routes + protected := r.Group("/") + protected.Use(s.auth.AuthRequired()) + { + protected.POST("/v1/chat/completions", s.handleChatCompletion) + protected.POST("/v1/completions", s.handleCompletion) + } r.GET("/v1/models", s.handleModels) r.GET("/v1/embeddings", s.handleEmbeddings) r.GET("/health", s.handleHealth)