diff --git a/.gitignore b/.gitignore index 8039089..ff5deec 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # Built binary openai-api +# env +.env + # Binaries for programs and plugins *.exe *.exe~ diff --git a/Dockerfile b/Dockerfile index d9cf04b..e4829f4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,14 +4,16 @@ FROM golang:1.21 AS builder # Set the working directory inside the container WORKDIR /app -# Copy from local -COPY . . +COPY . ./ # Download all the dependencies RUN go mod download +# Generate the Prisma Client Go client +RUN go generate ./db + # Build the Go application with CGO disabled and statically linked -RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o app . +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o api . # Use a minimal base image for running the application FROM alpine:latest @@ -28,4 +30,4 @@ COPY --from=builder /app/app . EXPOSE 8080 # Set the entry point to run the binary -ENTRYPOINT ["./app"] \ No newline at end of file +ENTRYPOINT ["./api"] \ No newline at end of file diff --git a/Makefile b/Makefile index e1436df..428106b 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ # Default target -all: build +all: db build # Build the Go application build: - go build openai-api.go \ No newline at end of file + go build openai-api.go + +db: + go generate ./db \ No newline at end of file diff --git a/README.md b/README.md index 8e4abab..8edd3d5 100644 --- a/README.md +++ b/README.md @@ -189,3 +189,23 @@ type LlmStreamChunk struct { Done bool `json:"done,omitempty"` } ``` + +## Prisma & Postgres + +1. Install the auto-generated query builder for go + +```sh +go get github.com/steebchen/prisma-client-go +``` + +To generate the schema from an existing database, specify the DB source in our `schema.prisma` file and pull the schema + +```sh +go run github.com/steebchen/prisma-client-go db pull +``` + +Once we have our schema we can generate our client bindings: + +```sh +make db +``` diff --git a/db/.gitignore b/db/.gitignore new file mode 100644 index 0000000..a0c7514 --- /dev/null +++ b/db/.gitignore @@ -0,0 +1,2 @@ +# gitignore generated by Prisma Client Go. DO NOT EDIT. +*_gen.go diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..ba11c9e --- /dev/null +++ b/db/db.go @@ -0,0 +1,146 @@ +//go:generate go run github.com/steebchen/prisma-client-go generate + +package db + +import ( + "context" + "errors" + "time" +) + +type APIKeyStore struct { + client *PrismaClient +} + +func NewAPIKeyStore() (*APIKeyStore, error) { + client := NewClient() + if err := client.Prisma.Connect(); err != nil { + return nil, err + } + + return &APIKeyStore{ + client: client, + }, nil +} + +func (s *APIKeyStore) ValidateAndGetAPIKey(ctx context.Context, apiKey string) (*APIKeyModel, error) { + key, err := s.client.APIKey.FindFirst( + APIKey.Key.Equals(apiKey), + APIKey.IsActive.Equals(true), + ).With( + APIKey.User.Fetch(), + ).Exec(ctx) + + if err != nil { + if errors.Is(err, ErrNotFound) { + return nil, errors.New("invalid API key") + } + return nil, err + } + + return key, nil +} + +func (s *APIKeyStore) RecordAPIUsage(ctx context.Context, apiKeyID string, userID string, tokens int) error { + now := time.Now() + startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + + // First get current usage (outside transaction) + user, err := s.client.User.FindUnique( + User.ID.Equals(userID), + ).Exec(ctx) + if err != nil { + return err + } + + // Calculate new usage values + dailyUsage := user.DailyTokenUsage + if user.UsageResetAt.Before(startOfDay) { + dailyUsage = 0 + } + + // Create transaction for updates + updateUserQuery := s.client.User.FindUnique( + User.ID.Equals(userID), + ).Update( + User.DailyTokenUsage.Set(dailyUsage+tokens), + User.UsageResetAt.Set(now), + ).Tx() + + updateAPIKeyQuery := s.client.APIKey.FindUnique( + APIKey.ID.Equals(apiKeyID), + ).Update( + APIKey.DailyTokenUsage.Set(tokens), + APIKey.UsageResetAt.Set(now), + APIKey.LastUsedAt.Set(now), + ).Tx() + + upsertUsageQuery := s.client.DailyUsage.UpsertOne( + DailyUsage.DateUserIDAPIKeyID( + DailyUsage.Date.Equals(startOfDay), + DailyUsage.UserID.Equals(userID), + DailyUsage.APIKeyID.Equals(apiKeyID)), + ).Create( + DailyUsage.Date.Set(startOfDay), + DailyUsage.User.Link(User.ID.Equals(userID)), + DailyUsage.APIKeyID.Set(apiKeyID), + DailyUsage.TokenUsage.Set(tokens), + ).Update( + DailyUsage.TokenUsage.Increment(tokens), + ).Tx() + + // Execute transaction for updates only + err = s.client.Prisma.TX.Transaction( + updateUserQuery, + updateAPIKeyQuery, + upsertUsageQuery, + ).Exec(ctx) + + return err +} + +func (s *APIKeyStore) GetUserForAPIKey(ctx context.Context, apiKey string) (*UserModel, error) { + user, err := s.client.User.FindFirst( + User.APIKey.Some( + APIKey.Key.Equals(apiKey), + ), + ).Exec(ctx) + + if err != nil { + return nil, err + } + + return user, nil +} + +func (s *APIKeyStore) CheckUsageLimit(ctx context.Context, userID string, tokens int) (bool, error) { + user, err := s.client.User.FindUnique( + User.ID.Equals(userID), + ).Exec(ctx) + + if err != nil { + return false, err + } + + now := time.Now() + startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + + // Reset daily usage if it's a new day + dailyUsage := user.DailyTokenUsage + if user.UsageResetAt.Before(startOfDay) { + dailyUsage = 0 + } + + // Get limit based on user tier + var limit int + switch user.Tier { + case UserTierFree: + limit = 10000 + case UserTierPro: + limit = 100000 + case UserTierEnterprise: + limit = 1000000 + } + + return (dailyUsage + tokens) <= limit, nil +} diff --git a/db/schema.prisma b/db/schema.prisma new file mode 100644 index 0000000..746b4ce --- /dev/null +++ b/db/schema.prisma @@ -0,0 +1,104 @@ +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +generator db { + provider = "go run github.com/steebchen/prisma-client-go" + output = "./" +} + +model APIKey { + id String @id + key String @unique + name String? + userId String + type ApiKeyType @default(CUSTOMER) + dailyTokenUsage Int @default(0) + usageResetAt DateTime @default(now()) + createdAt DateTime @default(now()) + lastUsedAt DateTime @default(now()) + isActive Boolean @default(true) + User User @relation(fields: [userId], references: [id]) + DailyUsage DailyUsage[] + + @@index([key, isActive]) + @@index([userId, lastUsedAt]) + @@index([userId, type]) +} + +model Conversation { + id String @id + title String + userId String + createdAt DateTime @default(now()) + updatedAt DateTime + User User @relation(fields: [userId], references: [id]) + Message Message[] + + @@index([userId, updatedAt]) +} + +model DailyUsage { + id String @id @default(cuid()) + date DateTime @db.Date + tokenUsage Int @default(0) + userId String + apiKeyId String? + createdAt DateTime @default(now()) + APIKey APIKey? @relation(fields: [apiKeyId], references: [id]) + User User @relation(fields: [userId], references: [id]) + + @@unique([date, userId, apiKeyId]) + @@index([apiKeyId, date]) + @@index([userId, date]) +} + +model Message { + id String @id + orderIndex Int @default(autoincrement()) + content String + role MessageRole + conversationId String + createdAt DateTime @default(now()) + Conversation Conversation @relation(fields: [conversationId], references: [id]) + + @@index([conversationId, createdAt]) + @@index([conversationId, orderIndex]) +} + +model User { + id String @id + name String? + email String? @unique + image String? + loginType String? + dailyTokenUsage Int @default(0) + usageResetAt DateTime @default(now()) + tier UserTier @default(FREE) + isActive Boolean @default(true) + createdAt DateTime @default(now()) + updatedAt DateTime + APIKey APIKey[] + Conversation Conversation[] + DailyUsage DailyUsage[] + + @@index([email]) +} + +enum ApiKeyType { + MASTER + CUSTOMER +} + +enum MessageRole { + system + user + assistant +} + +enum UserTier { + FREE + PRO + ENTERPRISE +} diff --git a/go.mod b/go.mod index a5fb08f..e41cc1e 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,10 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/golang/glog v1.2.2 github.com/google/uuid v1.6.0 + github.com/joho/godotenv v1.5.1 github.com/livepeer/ai-worker v0.7.0 + github.com/shopspring/decimal v1.4.0 + github.com/steebchen/prisma-client-go v0.42.0 ) require ( diff --git a/go.sum b/go.sum index da6486a..f71eacc 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -119,9 +121,13 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= +github.com/steebchen/prisma-client-go v0.42.0 h1:83keN+4jGvoTccCKCk74UU5JQj6pOwPcg3/zkoqxKJE= +github.com/steebchen/prisma-client-go v0.42.0/go.mod h1:wp2xU9HO5WIefc65vcl1HOiFUzaHKyOhHw5atrzs8hc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/openai-api.go b/openai-api.go index c0a9b3b..8d6e7d3 100644 --- a/openai-api.go +++ b/openai-api.go @@ -4,6 +4,7 @@ import ( "flag" "log" + "github.com/livepool-io/openai-middleware/db" "github.com/livepool-io/openai-middleware/middleware" "github.com/livepool-io/openai-middleware/server" ) @@ -12,8 +13,12 @@ func main() { gatewayURL := flag.String("gateway", "http://your-api-host", "The URL of the gateway API") port := flag.String("port", "8080", "The port to run the server on") flag.Parse() + apiKeyStore, err := db.NewAPIKeyStore() + if err != nil { + log.Fatalf("Failed to create Supabase API key store: %v", err) + } gateway := middleware.NewGateway(*gatewayURL) - server, err := server.NewServer(gateway) + server, err := server.NewServer(apiKeyStore, gateway) if err != nil { log.Fatalf("Failed to create server: %v", err) } diff --git a/server/http.go b/server/http.go index 341f116..816eaa8 100644 --- a/server/http.go +++ b/server/http.go @@ -3,21 +3,26 @@ package server import ( + "context" + "encoding/json" "fmt" + "io" + "log" "net/http" "github.com/gin-gonic/gin" "github.com/livepool-io/openai-middleware/common" + "github.com/livepool-io/openai-middleware/db" "github.com/livepool-io/openai-middleware/middleware" "github.com/livepool-io/openai-middleware/models" ) type Server struct { - // We can add fields here later if needed, such as a database connection + db *db.APIKeyStore gateway *middleware.Gateway } -func NewServer(gateway *middleware.Gateway) (*Server, error) { +func NewServer(apiKeys *db.APIKeyStore, gateway *middleware.Gateway) (*Server, error) { if gateway == nil { return nil, fmt.Errorf("gateway is required") } @@ -28,8 +33,12 @@ func NewServer(gateway *middleware.Gateway) (*Server, error) { func (s *Server) Start(port string) error { r := gin.Default() - r.POST("/v1/chat/completions", s.handleChatCompletion) - r.POST("/v1/completions", s.handleCompletion) + protected := r.Group("/") + protected.Use(s.authRequired()) + { + protected.POST("/v1/chat/completions", s.handleChatCompletion) + protected.POST("/v1/completions", s.handleCompletion) + } r.GET("/v1/models", s.handleModels) r.GET("/v1/embeddings", s.handleEmbeddings) r.GET("/health", s.handleHealth) @@ -56,6 +65,24 @@ func (s *Server) handleChatCompletion(c *gin.Context) { return } + // Check usage allowed + apiKeyI, ok := c.Get("apiKey") + if !ok { + c.JSON(500, gin.H{"error": "API key not found"}) + return + } + apiKey := apiKeyI.(*db.APIKeyModel) + // check usage allowed + ok, err = s.db.CheckUsageLimit(c, apiKey.UserID, *req.MaxTokens) + if err != nil { + c.JSON(429, gin.H{"error": err.Error()}) + return + } + if !ok { + c.JSON(429, gin.H{"error": "Usage limit exceeded"}) + return + } + // Call gateway resp, err := s.gateway.PostLlmGenerate(*req) if err != nil { @@ -63,26 +90,63 @@ func (s *Server) handleChatCompletion(c *gin.Context) { return } + var tokensUsed int if openAIReq.Stream { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") - // Handle streaming response and - // forward stream to caller in OpenAPI format - if err := s.gateway.HandleStreamingResponse(c.Request.Context(), req, c.Writer, resp); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + streamChan, errChan := common.HandleStreamingResponse(c.Request.Context(), req, resp) + + for { + select { + case <-c.Request.Context().Done(): + return + case err := <-errChan: + if err != io.EOF { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + return + case chunk := <-streamChan: + data, err := json.Marshal(chunk) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + _, err = fmt.Fprintf(c.Writer, "data: %s\n\n", data) + if err != nil { + return + } + c.Writer.(http.Flusher).Flush() + + // Record usage when we get the final chunk + if chunk.Choices[0].FinishReason == "stop" && chunk.Usage != nil { + tokensUsed = chunk.Usage.TotalTokens + } + } } } else { - // Transform Gateway API response to OpenAI format + // Handle non-streaming response openAIResp, err := common.TransformResponse(req, resp) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + + tokensUsed = openAIResp.Usage.TotalTokens + c.JSON(http.StatusOK, openAIResp) } + + go func() { + if err := s.db.RecordAPIUsage(context.Background(), + apiKey.ID, + apiKey.UserID, + tokensUsed); err != nil { + log.Printf("Failed to record API usage: %v", err) + } + }() } func (s *Server) handleCompletion(c *gin.Context) { @@ -100,3 +164,24 @@ func (s *Server) handleEmbeddings(c *gin.Context) { func (s *Server) handleHealth(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) } + +func (s *Server) authRequired() gin.HandlerFunc { + return func(c *gin.Context) { + apiKey := c.GetHeader("X-API-Key") + if apiKey == "" { + c.JSON(401, gin.H{"error": "API key is required"}) + c.Abort() + return + } + + key, err := s.db.ValidateAndGetAPIKey(c, apiKey) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + c.Abort() + return + } + + c.Set("apiKey", key) + c.Next() + } +}