draft grpc
This commit is contained in:
247
internal/middleware/authenticator.go
Normal file
247
internal/middleware/authenticator.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"stream.api/internal/database/model"
|
||||
"stream.api/internal/database/query"
|
||||
"stream.api/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
ActorMarkerMetadataKey = "x-stream-internal-auth"
|
||||
ActorIDMetadataKey = "x-stream-actor-id"
|
||||
ActorEmailMetadataKey = "x-stream-actor-email"
|
||||
ActorRoleMetadataKey = "x-stream-actor-role"
|
||||
)
|
||||
|
||||
type AuthResult struct {
|
||||
UserID string
|
||||
User *model.User
|
||||
}
|
||||
|
||||
type Actor struct {
|
||||
UserID string
|
||||
Email string
|
||||
Role string
|
||||
}
|
||||
|
||||
type Authenticator struct {
|
||||
db *gorm.DB
|
||||
logger logger.Logger
|
||||
trustedMarker string
|
||||
}
|
||||
|
||||
func NewAuthenticator(db *gorm.DB, l logger.Logger, trustedMarker string) *Authenticator {
|
||||
return &Authenticator{
|
||||
db: db,
|
||||
logger: l,
|
||||
trustedMarker: strings.TrimSpace(trustedMarker),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Authenticator) Authenticate(ctx context.Context) (*AuthResult, error) {
|
||||
actor, err := a.RequireActor(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u := query.User
|
||||
user, err := u.WithContext(ctx).Where(u.ID.Eq(actor.UserID)).First()
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "Unauthorized")
|
||||
}
|
||||
|
||||
user, err = a.syncSubscriptionState(ctx, user)
|
||||
if err != nil {
|
||||
a.logger.Error("Failed to sync subscription state", "error", err, "user_id", actor.UserID)
|
||||
return nil, status.Error(codes.Internal, "Failed to load user subscription state")
|
||||
}
|
||||
|
||||
if user.Role != nil && strings.EqualFold(strings.TrimSpace(*user.Role), "block") {
|
||||
return nil, status.Error(codes.PermissionDenied, "Forbidden: User is blocked")
|
||||
}
|
||||
|
||||
return &AuthResult{
|
||||
UserID: user.ID,
|
||||
User: user,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) RequireActor(ctx context.Context) (*Actor, error) {
|
||||
md, err := a.requireTrustedMetadata(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(firstMetadataValue(md, ActorIDMetadataKey))
|
||||
role := strings.TrimSpace(firstMetadataValue(md, ActorRoleMetadataKey))
|
||||
if userID == "" || role == "" {
|
||||
return nil, status.Error(codes.Unauthenticated, "Missing actor identity")
|
||||
}
|
||||
|
||||
return &Actor{
|
||||
UserID: userID,
|
||||
Email: strings.TrimSpace(firstMetadataValue(md, ActorEmailMetadataKey)),
|
||||
Role: role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) RequireInternalCall(ctx context.Context) error {
|
||||
_, err := a.requireTrustedMetadata(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Authenticator) requireTrustedMetadata(ctx context.Context) (metadata.MD, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Unauthenticated, "Missing actor metadata")
|
||||
}
|
||||
|
||||
marker := firstMetadataValue(md, ActorMarkerMetadataKey)
|
||||
if marker == "" || marker != a.trustedMarker {
|
||||
return nil, status.Error(codes.Unauthenticated, "Invalid internal auth marker")
|
||||
}
|
||||
|
||||
return md, nil
|
||||
}
|
||||
|
||||
func firstMetadataValue(md metadata.MD, key string) string {
|
||||
values := md.Get(key)
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return values[0]
|
||||
}
|
||||
|
||||
func (a *Authenticator) syncSubscriptionState(ctx context.Context, user *model.User) (*model.User, error) {
|
||||
subscription, err := model.GetLatestPlanSubscription(ctx, a.db, user.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return user, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if err := a.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var lockedSubscription model.PlanSubscription
|
||||
if err := tx.WithContext(ctx).
|
||||
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("id = ?", subscription.ID).
|
||||
First(&lockedSubscription).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if lockedSubscription.ExpiresAt.After(now) {
|
||||
if user.PlanID == nil || strings.TrimSpace(*user.PlanID) != lockedSubscription.PlanID {
|
||||
if err := tx.WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", user.ID).
|
||||
Update("plan_id", lockedSubscription.PlanID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
user.PlanID = &lockedSubscription.PlanID
|
||||
}
|
||||
|
||||
reminderDays, reminderField := reminderFieldForSubscription(&lockedSubscription, now)
|
||||
if reminderField != "" {
|
||||
sentAt := now
|
||||
notification := &model.Notification{
|
||||
ID: uuidString(),
|
||||
UserID: user.ID,
|
||||
Type: "billing.subscription_expiring",
|
||||
Title: "Plan expiring soon",
|
||||
Message: reminderMessage(reminderDays),
|
||||
ActionURL: model.StringPtr("/settings/billing"),
|
||||
ActionLabel: model.StringPtr("Renew plan"),
|
||||
Metadata: model.StringPtr(mustMarshalAuthJSON(map[string]interface{}{"plan_id": lockedSubscription.PlanID, "expires_at": lockedSubscription.ExpiresAt.UTC().Format(time.RFC3339), "reminder_days": reminderDays})),
|
||||
}
|
||||
if err := tx.WithContext(ctx).Create(notification).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.WithContext(ctx).
|
||||
Model(&model.PlanSubscription{}).
|
||||
Where("id = ?", lockedSubscription.ID).
|
||||
Update(reminderField, sentAt).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if user.PlanID != nil && strings.TrimSpace(*user.PlanID) != "" {
|
||||
if err := tx.WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", user.ID).
|
||||
Update("plan_id", nil).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
user.PlanID = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func reminderFieldForSubscription(subscription *model.PlanSubscription, now time.Time) (int, string) {
|
||||
if subscription == nil || !subscription.ExpiresAt.After(now) {
|
||||
return 0, ""
|
||||
}
|
||||
|
||||
remaining := subscription.ExpiresAt.Sub(now)
|
||||
switch {
|
||||
case remaining <= 24*time.Hour:
|
||||
if subscription.Reminder1DSentAt == nil {
|
||||
return 1, "reminder_1d_sent_at"
|
||||
}
|
||||
case remaining <= 72*time.Hour:
|
||||
if subscription.Reminder3DSentAt == nil {
|
||||
return 3, "reminder_3d_sent_at"
|
||||
}
|
||||
case remaining <= 7*24*time.Hour:
|
||||
if subscription.Reminder7DSentAt == nil {
|
||||
return 7, "reminder_7d_sent_at"
|
||||
}
|
||||
}
|
||||
|
||||
return 0, ""
|
||||
}
|
||||
|
||||
func reminderMessage(days int) string {
|
||||
switch days {
|
||||
case 1:
|
||||
return "Your current plan expires in 1 day. Renew now to avoid interruption."
|
||||
case 3:
|
||||
return "Your current plan expires in 3 days. Renew now to keep access active."
|
||||
default:
|
||||
return "Your current plan expires in 7 days. Renew now to keep your plan active."
|
||||
}
|
||||
}
|
||||
|
||||
func mustMarshalAuthJSON(value interface{}) string {
|
||||
encoded, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
func uuidString() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
Reference in New Issue
Block a user