Remove unused gRPC and JWT related code, including Woodpecker service definitions and JWT token management.
This commit is contained in:
@@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
goredis "github.com/redis/go-redis/v9"
|
goredis "github.com/redis/go-redis/v9"
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/video/runtime/domain"
|
"stream.api/internal/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -78,21 +78,21 @@ func (r *RedisAdapter) Dequeue(ctx context.Context) (*model.Job, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RedisAdapter) Publish(ctx context.Context, jobID string, logLine string, progress float64) error {
|
func (r *RedisAdapter) Publish(ctx context.Context, jobID string, logLine string, progress float64) error {
|
||||||
payload, err := json.Marshal(domain.LogEntry{JobID: jobID, Line: logLine, Progress: progress})
|
payload, err := json.Marshal(dto.LogEntry{JobID: jobID, Line: logLine, Progress: progress})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return r.client.Publish(ctx, LogChannel, payload).Err()
|
return r.client.Publish(ctx, LogChannel, payload).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RedisAdapter) Subscribe(ctx context.Context, jobID string) (<-chan domain.LogEntry, error) {
|
func (r *RedisAdapter) Subscribe(ctx context.Context, jobID string) (<-chan dto.LogEntry, error) {
|
||||||
pubsub := r.client.Subscribe(ctx, LogChannel)
|
pubsub := r.client.Subscribe(ctx, LogChannel)
|
||||||
ch := make(chan domain.LogEntry)
|
ch := make(chan dto.LogEntry)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
defer pubsub.Close()
|
defer pubsub.Close()
|
||||||
for msg := range pubsub.Channel() {
|
for msg := range pubsub.Channel() {
|
||||||
var entry domain.LogEntry
|
var entry dto.LogEntry
|
||||||
if err := json.Unmarshal([]byte(msg.Payload), &entry); err != nil {
|
if err := json.Unmarshal([]byte(msg.Payload), &entry); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -112,21 +112,21 @@ func (r *RedisAdapter) PublishResource(ctx context.Context, agentID string, data
|
|||||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
payload, err := json.Marshal(domain.SystemResource{AgentID: agentID, CPU: decoded.CPU, RAM: decoded.RAM})
|
payload, err := json.Marshal(dto.SystemResource{AgentID: agentID, CPU: decoded.CPU, RAM: decoded.RAM})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return r.client.Publish(ctx, ResourceChannel, payload).Err()
|
return r.client.Publish(ctx, ResourceChannel, payload).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RedisAdapter) SubscribeResources(ctx context.Context) (<-chan domain.SystemResource, error) {
|
func (r *RedisAdapter) SubscribeResources(ctx context.Context) (<-chan dto.SystemResource, error) {
|
||||||
pubsub := r.client.Subscribe(ctx, ResourceChannel)
|
pubsub := r.client.Subscribe(ctx, ResourceChannel)
|
||||||
ch := make(chan domain.SystemResource)
|
ch := make(chan dto.SystemResource)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
defer pubsub.Close()
|
defer pubsub.Close()
|
||||||
for msg := range pubsub.Channel() {
|
for msg := range pubsub.Channel() {
|
||||||
var entry domain.SystemResource
|
var entry dto.SystemResource
|
||||||
if err := json.Unmarshal([]byte(msg.Payload), &entry); err != nil {
|
if err := json.Unmarshal([]byte(msg.Payload), &entry); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package domain
|
package dto
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
@@ -24,3 +24,7 @@ type Agent struct {
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
type AgentWithStats struct {
|
||||||
|
*Agent
|
||||||
|
ActiveJobCount int64 `json:"active_job_count"`
|
||||||
|
}
|
||||||
46
internal/dto/job.go
Normal file
46
internal/dto/job.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import "stream.api/internal/database/model"
|
||||||
|
|
||||||
|
type JobStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
JobStatusPending JobStatus = "pending"
|
||||||
|
JobStatusRunning JobStatus = "running"
|
||||||
|
JobStatusSuccess JobStatus = "success"
|
||||||
|
JobStatusFailure JobStatus = "failure"
|
||||||
|
JobStatusCancelled JobStatus = "cancelled"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultJobPageSize = 20
|
||||||
|
MaxJobPageSize = 100
|
||||||
|
JobCursorVersion = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type PaginatedJobs struct {
|
||||||
|
Jobs []*model.Job `json:"jobs"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Offset int `json:"offset"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
HasMore bool `json:"has_more"`
|
||||||
|
NextCursor string `json:"next_cursor,omitempty"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JobListCursor struct {
|
||||||
|
Version int `json:"v"`
|
||||||
|
CreatedAtUnixNano int64 `json:"created_at_unix_nano"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type JobConfigEnvelope struct {
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Commands []string `json:"commands,omitempty"`
|
||||||
|
Environment map[string]string `json:"environment,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
VideoID string `json:"video_id,omitempty"`
|
||||||
|
TimeLimit int64 `json:"time_limit,omitempty"`
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package domain
|
package dto
|
||||||
|
|
||||||
type LogEntry struct {
|
type LogEntry struct {
|
||||||
JobID string `json:"job_id"`
|
JobID string `json:"job_id"`
|
||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
appv1 "stream.api/internal/api/proto/app/v1"
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/video"
|
runtimeservices "stream.api/internal/service/runtime/services"
|
||||||
runtimeservices "stream.api/internal/video/runtime/services"
|
"stream.api/internal/service/video"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestListAdminJobsCursorPagination(t *testing.T) {
|
func TestListAdminJobsCursorPagination(t *testing.T) {
|
||||||
|
|||||||
619
internal/service/admin_helpers.go
Normal file
619
internal/service/admin_helpers.go
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
|
"stream.api/internal/database/model"
|
||||||
|
"stream.api/internal/dto"
|
||||||
|
"stream.api/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *appServices) requireAdmin(ctx context.Context) (*middleware.AuthResult, error) {
|
||||||
|
result, err := s.authenticate(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if result.User == nil || result.User.Role == nil || strings.ToUpper(strings.TrimSpace(*result.User.Role)) != "ADMIN" {
|
||||||
|
return nil, status.Error(codes.PermissionDenied, "Admin access required")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) ensurePlanExists(ctx context.Context, planID *string) error {
|
||||||
|
if planID == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(*planID)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var count int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.Plan{}).Where("id = ?", trimmed).Count(&count).Error; err != nil {
|
||||||
|
return status.Error(codes.Internal, "Failed to validate plan")
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
return status.Error(codes.InvalidArgument, "Plan not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) saveAdminVideoAdConfig(ctx context.Context, tx *gorm.DB, video *model.Video, userID string, adTemplateID *string) error {
|
||||||
|
if video == nil || adTemplateID == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(*adTemplateID)
|
||||||
|
if trimmed == "" {
|
||||||
|
if err := tx.WithContext(ctx).Model(&model.Video{}).Where("id = ?", video.ID).Update("ad_id", nil).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
video.AdID = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var template model.AdTemplate
|
||||||
|
if err := tx.WithContext(ctx).Select("id").Where("id = ? AND user_id = ?", trimmed, userID).First(&template).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("Ad template not found")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.WithContext(ctx).Model(&model.Video{}).Where("id = ?", video.ID).Update("ad_id", template.ID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
video.AdID = &template.ID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func adminPageLimitOffset(pageValue int32, limitValue int32) (int32, int32, int) {
|
||||||
|
page := pageValue
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
limit := limitValue
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
if limit > 100 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
offset := int((page - 1) * limit)
|
||||||
|
return page, limit, offset
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAdminJob(job *model.Job) *appv1.AdminJob {
|
||||||
|
if job == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
agentID := strconv.FormatInt(*job.AgentID, 10)
|
||||||
|
return &appv1.AdminJob{
|
||||||
|
Id: job.ID,
|
||||||
|
Status: string(*job.Status),
|
||||||
|
Priority: int32(*job.Priority),
|
||||||
|
UserId: *job.UserID,
|
||||||
|
Name: job.ID,
|
||||||
|
TimeLimit: *job.TimeLimit,
|
||||||
|
InputUrl: *job.InputURL,
|
||||||
|
OutputUrl: *job.OutputURL,
|
||||||
|
TotalDuration: *job.TotalDuration,
|
||||||
|
CurrentTime: *job.CurrentTime,
|
||||||
|
Progress: *job.Progress,
|
||||||
|
AgentId: &agentID,
|
||||||
|
Logs: *job.Logs,
|
||||||
|
Config: *job.Config,
|
||||||
|
Cancelled: *job.Cancelled,
|
||||||
|
RetryCount: int32(*job.RetryCount),
|
||||||
|
MaxRetries: int32(*job.MaxRetries),
|
||||||
|
CreatedAt: timestamppb.New(*job.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(*job.UpdatedAt),
|
||||||
|
VideoId: stringPointerOrNil(*job.VideoID),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAdminAgent(agent *dto.AgentWithStats) *appv1.AdminAgent {
|
||||||
|
if agent == nil || agent.Agent == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.AdminAgent{
|
||||||
|
Id: agent.ID,
|
||||||
|
Name: agent.Name,
|
||||||
|
Platform: agent.Platform,
|
||||||
|
Backend: agent.Backend,
|
||||||
|
Version: agent.Version,
|
||||||
|
Capacity: agent.Capacity,
|
||||||
|
Status: string(agent.Status),
|
||||||
|
Cpu: agent.CPU,
|
||||||
|
Ram: agent.RAM,
|
||||||
|
LastHeartbeat: timestamppb.New(agent.LastHeartbeat),
|
||||||
|
CreatedAt: timestamppb.New(agent.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(agent.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAdminRoleValue(value string) string {
|
||||||
|
role := strings.ToUpper(strings.TrimSpace(value))
|
||||||
|
if role == "" {
|
||||||
|
return "USER"
|
||||||
|
}
|
||||||
|
return role
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidAdminRoleValue(role string) bool {
|
||||||
|
switch normalizeAdminRoleValue(role) {
|
||||||
|
case "USER", "ADMIN", "BLOCK":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminUser(ctx context.Context, user *model.User) (*appv1.AdminUser, error) {
|
||||||
|
if user == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &appv1.AdminUser{
|
||||||
|
Id: user.ID,
|
||||||
|
Email: user.Email,
|
||||||
|
Username: nullableTrimmedString(user.Username),
|
||||||
|
Avatar: nullableTrimmedString(user.Avatar),
|
||||||
|
Role: nullableTrimmedString(user.Role),
|
||||||
|
PlanId: nullableTrimmedString(user.PlanID),
|
||||||
|
StorageUsed: user.StorageUsed,
|
||||||
|
CreatedAt: timeToProto(user.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(user.UpdatedAt.UTC()),
|
||||||
|
WalletBalance: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
videoCount, err := s.loadAdminUserVideoCount(ctx, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.VideoCount = videoCount
|
||||||
|
|
||||||
|
walletBalance, err := model.GetWalletBalance(ctx, s.db, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.WalletBalance = walletBalance
|
||||||
|
|
||||||
|
planName, err := s.loadAdminPlanName(ctx, user.PlanID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.PlanName = planName
|
||||||
|
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminUserDetail(ctx context.Context, user *model.User, subscription *model.PlanSubscription) (*appv1.AdminUserDetail, error) {
|
||||||
|
payload, err := s.buildAdminUser(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
referral, err := s.buildAdminUserReferralInfo(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &appv1.AdminUserDetail{
|
||||||
|
User: payload,
|
||||||
|
Subscription: toProtoPlanSubscription(subscription),
|
||||||
|
Referral: referral,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminUserReferralInfo(ctx context.Context, user *model.User) (*appv1.AdminUserReferralInfo, error) {
|
||||||
|
if user == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var referrer *appv1.ReferralUserSummary
|
||||||
|
if user.ReferredByUserID != nil && strings.TrimSpace(*user.ReferredByUserID) != "" {
|
||||||
|
loadedReferrer, err := s.loadReferralUserSummary(ctx, strings.TrimSpace(*user.ReferredByUserID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
referrer = loadedReferrer
|
||||||
|
}
|
||||||
|
|
||||||
|
bps := effectiveReferralRewardBps(user.ReferralRewardBps)
|
||||||
|
referral := &appv1.AdminUserReferralInfo{
|
||||||
|
Referrer: referrer,
|
||||||
|
ReferralEligible: referralUserEligible(user),
|
||||||
|
EffectiveRewardPercent: referralRewardBpsToPercent(bps),
|
||||||
|
RewardOverridePercent: func() *float64 {
|
||||||
|
if user.ReferralRewardBps == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
value := referralRewardBpsToPercent(*user.ReferralRewardBps)
|
||||||
|
return &value
|
||||||
|
}(),
|
||||||
|
ShareLink: s.buildReferralShareLink(user.Username),
|
||||||
|
RewardGranted: referralRewardProcessed(user),
|
||||||
|
RewardGrantedAt: timeToProto(user.ReferralRewardGrantedAt),
|
||||||
|
RewardPaymentId: nullableTrimmedString(user.ReferralRewardPaymentID),
|
||||||
|
RewardAmount: user.ReferralRewardAmount,
|
||||||
|
}
|
||||||
|
return referral, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminVideo(ctx context.Context, video *model.Video) (*appv1.AdminVideo, error) {
|
||||||
|
if video == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
statusValue := stringValue(video.Status)
|
||||||
|
if statusValue == "" {
|
||||||
|
statusValue = "ready"
|
||||||
|
}
|
||||||
|
jobID, err := s.loadLatestVideoJobID(ctx, video.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &appv1.AdminVideo{
|
||||||
|
Id: video.ID,
|
||||||
|
UserId: video.UserID,
|
||||||
|
Title: video.Title,
|
||||||
|
Description: nullableTrimmedString(video.Description),
|
||||||
|
Url: video.URL,
|
||||||
|
Status: strings.ToLower(statusValue),
|
||||||
|
Size: video.Size,
|
||||||
|
Duration: video.Duration,
|
||||||
|
Format: video.Format,
|
||||||
|
CreatedAt: timeToProto(video.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(video.UpdatedAt.UTC()),
|
||||||
|
ProcessingStatus: nullableTrimmedString(video.ProcessingStatus),
|
||||||
|
JobId: jobID,
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerEmail, err := s.loadAdminUserEmail(ctx, video.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.OwnerEmail = ownerEmail
|
||||||
|
|
||||||
|
adTemplateID, adTemplateName, err := s.loadAdminVideoAdTemplateDetails(ctx, video)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.AdTemplateId = adTemplateID
|
||||||
|
payload.AdTemplateName = adTemplateName
|
||||||
|
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminPayment(ctx context.Context, payment *model.Payment) (*appv1.AdminPayment, error) {
|
||||||
|
if payment == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &appv1.AdminPayment{
|
||||||
|
Id: payment.ID,
|
||||||
|
UserId: payment.UserID,
|
||||||
|
PlanId: nullableTrimmedString(payment.PlanID),
|
||||||
|
Amount: payment.Amount,
|
||||||
|
Currency: normalizeCurrency(payment.Currency),
|
||||||
|
Status: normalizePaymentStatus(payment.Status),
|
||||||
|
Provider: strings.ToUpper(stringValue(payment.Provider)),
|
||||||
|
TransactionId: nullableTrimmedString(payment.TransactionID),
|
||||||
|
InvoiceId: payment.ID,
|
||||||
|
CreatedAt: timeToProto(payment.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(payment.UpdatedAt.UTC()),
|
||||||
|
}
|
||||||
|
|
||||||
|
userEmail, err := s.loadAdminUserEmail(ctx, payment.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.UserEmail = userEmail
|
||||||
|
|
||||||
|
planName, err := s.loadAdminPlanName(ctx, payment.PlanID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.PlanName = planName
|
||||||
|
|
||||||
|
termMonths, paymentMethod, expiresAt, walletAmount, topupAmount, err := s.loadAdminPaymentSubscriptionDetails(ctx, payment.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.TermMonths = termMonths
|
||||||
|
payload.PaymentMethod = paymentMethod
|
||||||
|
payload.ExpiresAt = expiresAt
|
||||||
|
payload.WalletAmount = walletAmount
|
||||||
|
payload.TopupAmount = topupAmount
|
||||||
|
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminUserVideoCount(ctx context.Context, userID string) (int64, error) {
|
||||||
|
var videoCount int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.Video{}).Where("user_id = ?", userID).Count(&videoCount).Error; err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return videoCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminUserEmail(ctx context.Context, userID string) (*string, error) {
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.WithContext(ctx).Select("id, email").Where("id = ?", userID).First(&user).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nullableTrimmedString(&user.Email), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadReferralUserSummary(ctx context.Context, userID string) (*appv1.ReferralUserSummary, error) {
|
||||||
|
if strings.TrimSpace(userID) == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.WithContext(ctx).Select("id, email, username").Where("id = ?", userID).First(&user).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &appv1.ReferralUserSummary{
|
||||||
|
Id: user.ID,
|
||||||
|
Email: user.Email,
|
||||||
|
Username: nullableTrimmedString(user.Username),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminPlanName(ctx context.Context, planID *string) (*string, error) {
|
||||||
|
if planID == nil || strings.TrimSpace(*planID) == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var plan model.Plan
|
||||||
|
if err := s.db.WithContext(ctx).Select("id, name").Where("id = ?", *planID).First(&plan).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nullableTrimmedString(&plan.Name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminVideoAdTemplateDetails(ctx context.Context, video *model.Video) (*string, *string, error) {
|
||||||
|
if video == nil {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
adTemplateID := nullableTrimmedString(video.AdID)
|
||||||
|
if adTemplateID == nil {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
adTemplateName, err := s.loadAdminAdTemplateName(ctx, *adTemplateID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return adTemplateID, adTemplateName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminAdTemplateName(ctx context.Context, adTemplateID string) (*string, error) {
|
||||||
|
var template model.AdTemplate
|
||||||
|
if err := s.db.WithContext(ctx).Select("id, name").Where("id = ?", adTemplateID).First(&template).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nullableTrimmedString(&template.Name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadLatestVideoJobID(ctx context.Context, videoID string) (*string, error) {
|
||||||
|
videoID = strings.TrimSpace(videoID)
|
||||||
|
if videoID == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var job model.Job
|
||||||
|
if err := s.db.WithContext(ctx).
|
||||||
|
Where("config::jsonb ->> 'video_id' = ?", videoID).
|
||||||
|
Order("created_at DESC").
|
||||||
|
First(&job).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stringPointerOrNil(job.ID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminPaymentSubscriptionDetails(ctx context.Context, paymentID string) (*int32, *string, *string, *float64, *float64, error) {
|
||||||
|
var subscription model.PlanSubscription
|
||||||
|
if err := s.db.WithContext(ctx).Where("payment_id = ?", paymentID).Order("created_at DESC").First(&subscription).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil, nil, nil, nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil, nil, nil, nil, err
|
||||||
|
}
|
||||||
|
termMonths := subscription.TermMonths
|
||||||
|
paymentMethod := nullableTrimmedString(&subscription.PaymentMethod)
|
||||||
|
expiresAt := subscription.ExpiresAt.UTC().Format(time.RFC3339)
|
||||||
|
walletAmount := subscription.WalletAmount
|
||||||
|
topupAmount := subscription.TopupAmount
|
||||||
|
return &termMonths, paymentMethod, nullableTrimmedString(&expiresAt), &walletAmount, &topupAmount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadAdminPlanUsageCounts(ctx context.Context, planID string) (int64, int64, int64, error) {
|
||||||
|
var userCount int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.User{}).Where("plan_id = ?", planID).Count(&userCount).Error; err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var paymentCount int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.Payment{}).Where("plan_id = ?", planID).Count(&paymentCount).Error; err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var subscriptionCount int64
|
||||||
|
if err := s.db.WithContext(ctx).Model(&model.PlanSubscription{}).Where("plan_id = ?", planID).Count(&subscriptionCount).Error; err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return userCount, paymentCount, subscriptionCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAdminPlanInput(name, cycle string, price float64, storageLimit int64, uploadLimit int32) string {
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return "Name is required"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cycle) == "" {
|
||||||
|
return "Cycle is required"
|
||||||
|
}
|
||||||
|
if price < 0 {
|
||||||
|
return "Price must be greater than or equal to 0"
|
||||||
|
}
|
||||||
|
if storageLimit <= 0 {
|
||||||
|
return "Storage limit must be greater than 0"
|
||||||
|
}
|
||||||
|
if uploadLimit <= 0 {
|
||||||
|
return "Upload limit must be greater than 0"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAdminAdTemplateInput(userID, name, vastTagURL, adFormat string, duration *int64) string {
|
||||||
|
if strings.TrimSpace(userID) == "" {
|
||||||
|
return "User ID is required"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) == "" || strings.TrimSpace(vastTagURL) == "" {
|
||||||
|
return "Name and VAST URL are required"
|
||||||
|
}
|
||||||
|
format := normalizeAdFormat(adFormat)
|
||||||
|
if format == "mid-roll" && (duration == nil || *duration <= 0) {
|
||||||
|
return "Duration is required for mid-roll templates"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAdminPlayerConfigInput(userID, name string) string {
|
||||||
|
if strings.TrimSpace(userID) == "" {
|
||||||
|
return "User ID is required"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return "Name is required"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) unsetAdminDefaultTemplates(ctx context.Context, tx *gorm.DB, userID, excludeID string) error {
|
||||||
|
query := tx.WithContext(ctx).Model(&model.AdTemplate{}).Where("user_id = ?", userID)
|
||||||
|
if excludeID != "" {
|
||||||
|
query = query.Where("id <> ?", excludeID)
|
||||||
|
}
|
||||||
|
return query.Update("is_default", false).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) unsetAdminDefaultPlayerConfigs(ctx context.Context, tx *gorm.DB, userID, excludeID string) error {
|
||||||
|
query := tx.WithContext(ctx).Model(&model.PlayerConfig{}).Where("user_id = ?", userID)
|
||||||
|
if excludeID != "" {
|
||||||
|
query = query.Where("id <> ?", excludeID)
|
||||||
|
}
|
||||||
|
return query.Update("is_default", false).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminPlan(ctx context.Context, plan *model.Plan) (*appv1.AdminPlan, error) {
|
||||||
|
if plan == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userCount, paymentCount, subscriptionCount, err := s.loadAdminPlanUsageCounts(ctx, plan.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &appv1.AdminPlan{
|
||||||
|
Id: plan.ID,
|
||||||
|
Name: plan.Name,
|
||||||
|
Description: nullableTrimmedString(plan.Description),
|
||||||
|
Features: append([]string(nil), plan.Features...),
|
||||||
|
Price: plan.Price,
|
||||||
|
Cycle: plan.Cycle,
|
||||||
|
StorageLimit: plan.StorageLimit,
|
||||||
|
UploadLimit: plan.UploadLimit,
|
||||||
|
DurationLimit: plan.DurationLimit,
|
||||||
|
QualityLimit: plan.QualityLimit,
|
||||||
|
IsActive: boolValue(plan.IsActive),
|
||||||
|
UserCount: userCount,
|
||||||
|
PaymentCount: paymentCount,
|
||||||
|
SubscriptionCount: subscriptionCount,
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminAdTemplate(ctx context.Context, item *model.AdTemplate) (*appv1.AdminAdTemplate, error) {
|
||||||
|
if item == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &appv1.AdminAdTemplate{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
Name: item.Name,
|
||||||
|
Description: nullableTrimmedString(item.Description),
|
||||||
|
VastTagUrl: item.VastTagURL,
|
||||||
|
AdFormat: stringValue(item.AdFormat),
|
||||||
|
Duration: item.Duration,
|
||||||
|
IsActive: boolValue(item.IsActive),
|
||||||
|
IsDefault: item.IsDefault,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(item.UpdatedAt),
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerEmail, err := s.loadAdminUserEmail(ctx, item.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.OwnerEmail = ownerEmail
|
||||||
|
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildAdminPlayerConfig(ctx context.Context, item *model.PlayerConfig) (*appv1.AdminPlayerConfig, error) {
|
||||||
|
if item == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := &appv1.AdminPlayerConfig{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
Name: item.Name,
|
||||||
|
Description: nullableTrimmedString(item.Description),
|
||||||
|
Autoplay: item.Autoplay,
|
||||||
|
Loop: item.Loop,
|
||||||
|
Muted: item.Muted,
|
||||||
|
ShowControls: boolValue(item.ShowControls),
|
||||||
|
Pip: boolValue(item.Pip),
|
||||||
|
Airplay: boolValue(item.Airplay),
|
||||||
|
Chromecast: boolValue(item.Chromecast),
|
||||||
|
IsActive: boolValue(item.IsActive),
|
||||||
|
IsDefault: item.IsDefault,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(&item.UpdatedAt),
|
||||||
|
EncrytionM3U8: boolValue(item.EncrytionM3u8),
|
||||||
|
LogoUrl: nullableTrimmedString(item.LogoURL),
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerEmail, err := s.loadAdminUserEmail(ctx, item.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload.OwnerEmail = ownerEmail
|
||||||
|
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
417
internal/service/payment_helpers.go
Normal file
417
internal/service/payment_helpers.go
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
|
"stream.api/internal/database/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func statusErrorWithBody(ctx context.Context, grpcCode codes.Code, httpCode int, message string, data any) error {
|
||||||
|
body := apiErrorBody{
|
||||||
|
Code: httpCode,
|
||||||
|
Message: message,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
encoded, err := json.Marshal(body)
|
||||||
|
if err == nil {
|
||||||
|
_ = grpc.SetTrailer(ctx, metadata.Pairs("x-error-body", string(encoded)))
|
||||||
|
}
|
||||||
|
return status.Error(grpcCode, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadPaymentPlanForUser(ctx context.Context, planID string) (*model.Plan, error) {
|
||||||
|
var planRecord model.Plan
|
||||||
|
if err := s.db.WithContext(ctx).Where("id = ?", planID).First(&planRecord).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Error(codes.NotFound, "Plan not found")
|
||||||
|
}
|
||||||
|
s.logger.Error("Failed to load plan", "error", err)
|
||||||
|
return nil, status.Error(codes.Internal, "Failed to create payment")
|
||||||
|
}
|
||||||
|
if planRecord.IsActive == nil || !*planRecord.IsActive {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Plan is not active")
|
||||||
|
}
|
||||||
|
return &planRecord, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadPaymentPlanForAdmin(ctx context.Context, planID string) (*model.Plan, error) {
|
||||||
|
var planRecord model.Plan
|
||||||
|
if err := s.db.WithContext(ctx).Where("id = ?", planID).First(&planRecord).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Plan not found")
|
||||||
|
}
|
||||||
|
return nil, status.Error(codes.Internal, "Failed to create payment")
|
||||||
|
}
|
||||||
|
if planRecord.IsActive == nil || !*planRecord.IsActive {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Plan is not active")
|
||||||
|
}
|
||||||
|
return &planRecord, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadPaymentUserForAdmin(ctx context.Context, userID string) (*model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
if err := s.db.WithContext(ctx).Where("id = ?", userID).First(&user).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "User not found")
|
||||||
|
}
|
||||||
|
return nil, status.Error(codes.Internal, "Failed to create payment")
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) executePaymentFlow(ctx context.Context, input paymentExecutionInput) (*paymentExecutionResult, error) {
|
||||||
|
totalAmount := input.Plan.Price * float64(input.TermMonths)
|
||||||
|
if totalAmount < 0 {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Amount must be greater than or equal to 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
statusValue := "SUCCESS"
|
||||||
|
provider := "INTERNAL"
|
||||||
|
currency := normalizeCurrency(nil)
|
||||||
|
transactionID := buildTransactionID("sub")
|
||||||
|
now := time.Now().UTC()
|
||||||
|
paymentRecord := &model.Payment{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: input.UserID,
|
||||||
|
PlanID: &input.Plan.ID,
|
||||||
|
Amount: totalAmount,
|
||||||
|
Currency: ¤cy,
|
||||||
|
Status: &statusValue,
|
||||||
|
Provider: &provider,
|
||||||
|
TransactionID: &transactionID,
|
||||||
|
}
|
||||||
|
invoiceID := buildInvoiceID(paymentRecord.ID)
|
||||||
|
|
||||||
|
result := &paymentExecutionResult{
|
||||||
|
Payment: paymentRecord,
|
||||||
|
InvoiceID: invoiceID,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
if _, err := lockUserForUpdate(ctx, tx, input.UserID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newExpiry, err := loadPaymentExpiry(ctx, tx, input.UserID, input.TermMonths, now)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
currentWalletBalance, err := model.GetWalletBalance(ctx, tx, input.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
validatedTopupAmount, err := validatePaymentFunding(ctx, input, totalAmount, currentWalletBalance)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Create(paymentRecord).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := createPaymentWalletTransactions(tx, input, paymentRecord, totalAmount, validatedTopupAmount, currency); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
subscription := buildPaymentSubscription(input, paymentRecord, totalAmount, validatedTopupAmount, now, newExpiry)
|
||||||
|
if err := tx.Create(subscription).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Model(&model.User{}).Where("id = ?", input.UserID).Update("plan_id", input.Plan.ID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
notification := buildSubscriptionNotification(input.UserID, paymentRecord.ID, invoiceID, input.Plan, subscription)
|
||||||
|
if err := tx.Create(notification).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := s.maybeGrantReferralReward(ctx, tx, input, paymentRecord, subscription); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
walletBalance, err := model.GetWalletBalance(ctx, tx, input.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result.Subscription = subscription
|
||||||
|
result.WalletBalance = walletBalance
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadPaymentExpiry(ctx context.Context, tx *gorm.DB, userID string, termMonths int32, now time.Time) (time.Time, error) {
|
||||||
|
currentSubscription, err := model.GetLatestPlanSubscription(ctx, tx, userID)
|
||||||
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
baseExpiry := now
|
||||||
|
if currentSubscription != nil && currentSubscription.ExpiresAt.After(baseExpiry) {
|
||||||
|
baseExpiry = currentSubscription.ExpiresAt.UTC()
|
||||||
|
}
|
||||||
|
return baseExpiry.AddDate(0, int(termMonths), 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePaymentFunding(ctx context.Context, input paymentExecutionInput, totalAmount, currentWalletBalance float64) (float64, error) {
|
||||||
|
shortfall := maxFloat(totalAmount-currentWalletBalance, 0)
|
||||||
|
if input.PaymentMethod == paymentMethodWallet && shortfall > 0 {
|
||||||
|
return 0, statusErrorWithBody(ctx, codes.InvalidArgument, http.StatusBadRequest, "Insufficient wallet balance", map[string]any{
|
||||||
|
"payment_method": input.PaymentMethod,
|
||||||
|
"wallet_balance": currentWalletBalance,
|
||||||
|
"total_amount": totalAmount,
|
||||||
|
"shortfall": shortfall,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if input.PaymentMethod != paymentMethodTopup {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if input.TopupAmount == nil {
|
||||||
|
return 0, statusErrorWithBody(ctx, codes.InvalidArgument, http.StatusBadRequest, "Top-up amount is required when payment method is topup", map[string]any{
|
||||||
|
"payment_method": input.PaymentMethod,
|
||||||
|
"wallet_balance": currentWalletBalance,
|
||||||
|
"total_amount": totalAmount,
|
||||||
|
"shortfall": shortfall,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
topupAmount := maxFloat(*input.TopupAmount, 0)
|
||||||
|
if topupAmount <= 0 {
|
||||||
|
return 0, statusErrorWithBody(ctx, codes.InvalidArgument, http.StatusBadRequest, "Top-up amount must be greater than 0", map[string]any{
|
||||||
|
"payment_method": input.PaymentMethod,
|
||||||
|
"wallet_balance": currentWalletBalance,
|
||||||
|
"total_amount": totalAmount,
|
||||||
|
"shortfall": shortfall,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if topupAmount < shortfall {
|
||||||
|
return 0, statusErrorWithBody(ctx, codes.InvalidArgument, http.StatusBadRequest, "Top-up amount must be greater than or equal to the required shortfall", map[string]any{
|
||||||
|
"payment_method": input.PaymentMethod,
|
||||||
|
"wallet_balance": currentWalletBalance,
|
||||||
|
"total_amount": totalAmount,
|
||||||
|
"shortfall": shortfall,
|
||||||
|
"topup_amount": topupAmount,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return topupAmount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPaymentWalletTransactions(tx *gorm.DB, input paymentExecutionInput, paymentRecord *model.Payment, totalAmount, topupAmount float64, currency string) error {
|
||||||
|
if input.PaymentMethod == paymentMethodTopup {
|
||||||
|
topupTransaction := &model.WalletTransaction{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: input.UserID,
|
||||||
|
Type: walletTransactionTypeTopup,
|
||||||
|
Amount: topupAmount,
|
||||||
|
Currency: model.StringPtr(currency),
|
||||||
|
Note: model.StringPtr(fmt.Sprintf("Wallet top-up for %s (%d months)", input.Plan.Name, input.TermMonths)),
|
||||||
|
PaymentID: &paymentRecord.ID,
|
||||||
|
PlanID: &input.Plan.ID,
|
||||||
|
TermMonths: int32Ptr(input.TermMonths),
|
||||||
|
}
|
||||||
|
if err := tx.Create(topupTransaction).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debitTransaction := &model.WalletTransaction{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: input.UserID,
|
||||||
|
Type: walletTransactionTypeSubscriptionDebit,
|
||||||
|
Amount: -totalAmount,
|
||||||
|
Currency: model.StringPtr(currency),
|
||||||
|
Note: model.StringPtr(fmt.Sprintf("Subscription payment for %s (%d months)", input.Plan.Name, input.TermMonths)),
|
||||||
|
PaymentID: &paymentRecord.ID,
|
||||||
|
PlanID: &input.Plan.ID,
|
||||||
|
TermMonths: int32Ptr(input.TermMonths),
|
||||||
|
}
|
||||||
|
return tx.Create(debitTransaction).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPaymentSubscription(input paymentExecutionInput, paymentRecord *model.Payment, totalAmount, topupAmount float64, now, newExpiry time.Time) *model.PlanSubscription {
|
||||||
|
return &model.PlanSubscription{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: input.UserID,
|
||||||
|
PaymentID: paymentRecord.ID,
|
||||||
|
PlanID: input.Plan.ID,
|
||||||
|
TermMonths: input.TermMonths,
|
||||||
|
PaymentMethod: input.PaymentMethod,
|
||||||
|
WalletAmount: totalAmount,
|
||||||
|
TopupAmount: topupAmount,
|
||||||
|
StartedAt: now,
|
||||||
|
ExpiresAt: newExpiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSubscriptionNotification(userID, paymentID, invoiceID string, planRecord *model.Plan, subscription *model.PlanSubscription) *model.Notification {
|
||||||
|
return &model.Notification{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: userID,
|
||||||
|
Type: "billing.subscription",
|
||||||
|
Title: "Subscription activated",
|
||||||
|
Message: fmt.Sprintf("Your subscription to %s is active until %s.", planRecord.Name, subscription.ExpiresAt.UTC().Format("2006-01-02")),
|
||||||
|
Metadata: model.StringPtr(mustMarshalJSON(map[string]any{
|
||||||
|
"payment_id": paymentID,
|
||||||
|
"invoice_id": invoiceID,
|
||||||
|
"plan_id": planRecord.ID,
|
||||||
|
"term_months": subscription.TermMonths,
|
||||||
|
"payment_method": subscription.PaymentMethod,
|
||||||
|
"wallet_amount": subscription.WalletAmount,
|
||||||
|
"topup_amount": subscription.TopupAmount,
|
||||||
|
"plan_expires_at": subscription.ExpiresAt.UTC().Format(time.RFC3339),
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildPaymentInvoice(ctx context.Context, paymentRecord *model.Payment) (string, string, error) {
|
||||||
|
details, err := s.loadPaymentInvoiceDetails(ctx, paymentRecord)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
createdAt := formatOptionalTimestamp(paymentRecord.CreatedAt)
|
||||||
|
lines := []string{
|
||||||
|
"Stream API Invoice",
|
||||||
|
fmt.Sprintf("Invoice ID: %s", buildInvoiceID(paymentRecord.ID)),
|
||||||
|
fmt.Sprintf("Payment ID: %s", paymentRecord.ID),
|
||||||
|
fmt.Sprintf("User ID: %s", paymentRecord.UserID),
|
||||||
|
fmt.Sprintf("Plan: %s", details.PlanName),
|
||||||
|
fmt.Sprintf("Amount: %.2f %s", paymentRecord.Amount, normalizeCurrency(paymentRecord.Currency)),
|
||||||
|
fmt.Sprintf("Status: %s", strings.ToUpper(normalizePaymentStatus(paymentRecord.Status))),
|
||||||
|
fmt.Sprintf("Provider: %s", strings.ToUpper(stringValue(paymentRecord.Provider))),
|
||||||
|
fmt.Sprintf("Payment Method: %s", strings.ToUpper(details.PaymentMethod)),
|
||||||
|
fmt.Sprintf("Transaction ID: %s", stringValue(paymentRecord.TransactionID)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if details.TermMonths != nil {
|
||||||
|
lines = append(lines, fmt.Sprintf("Term: %d month(s)", *details.TermMonths))
|
||||||
|
}
|
||||||
|
if details.ExpiresAt != nil {
|
||||||
|
lines = append(lines, fmt.Sprintf("Valid Until: %s", details.ExpiresAt.UTC().Format(time.RFC3339)))
|
||||||
|
}
|
||||||
|
if details.WalletAmount > 0 {
|
||||||
|
lines = append(lines, fmt.Sprintf("Wallet Applied: %.2f %s", details.WalletAmount, normalizeCurrency(paymentRecord.Currency)))
|
||||||
|
}
|
||||||
|
if details.TopupAmount > 0 {
|
||||||
|
lines = append(lines, fmt.Sprintf("Top-up Added: %.2f %s", details.TopupAmount, normalizeCurrency(paymentRecord.Currency)))
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf("Created At: %s", createdAt))
|
||||||
|
|
||||||
|
return strings.Join(lines, "\n"), buildInvoiceFilename(paymentRecord.ID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTopupInvoice(transaction *model.WalletTransaction) string {
|
||||||
|
createdAt := formatOptionalTimestamp(transaction.CreatedAt)
|
||||||
|
return strings.Join([]string{
|
||||||
|
"Stream API Wallet Top-up Invoice",
|
||||||
|
fmt.Sprintf("Invoice ID: %s", buildInvoiceID(transaction.ID)),
|
||||||
|
fmt.Sprintf("Wallet Transaction ID: %s", transaction.ID),
|
||||||
|
fmt.Sprintf("User ID: %s", transaction.UserID),
|
||||||
|
fmt.Sprintf("Amount: %.2f %s", transaction.Amount, normalizeCurrency(transaction.Currency)),
|
||||||
|
"Status: SUCCESS",
|
||||||
|
fmt.Sprintf("Type: %s", strings.ToUpper(transaction.Type)),
|
||||||
|
fmt.Sprintf("Note: %s", model.StringValue(transaction.Note)),
|
||||||
|
fmt.Sprintf("Created At: %s", createdAt),
|
||||||
|
}, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadPaymentInvoiceDetails(ctx context.Context, paymentRecord *model.Payment) (*paymentInvoiceDetails, error) {
|
||||||
|
details := &paymentInvoiceDetails{
|
||||||
|
PlanName: "Unknown plan",
|
||||||
|
PaymentMethod: paymentMethodWallet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if paymentRecord.PlanID != nil && strings.TrimSpace(*paymentRecord.PlanID) != "" {
|
||||||
|
var planRecord model.Plan
|
||||||
|
if err := s.db.WithContext(ctx).Where("id = ?", *paymentRecord.PlanID).First(&planRecord).Error; err != nil {
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
details.PlanName = planRecord.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var subscription model.PlanSubscription
|
||||||
|
if err := s.db.WithContext(ctx).
|
||||||
|
Where("payment_id = ?", paymentRecord.ID).
|
||||||
|
Order("created_at DESC").
|
||||||
|
First(&subscription).Error; err != nil {
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return details, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
details.TermMonths = &subscription.TermMonths
|
||||||
|
details.PaymentMethod = normalizePaymentMethod(subscription.PaymentMethod)
|
||||||
|
if details.PaymentMethod == "" {
|
||||||
|
details.PaymentMethod = paymentMethodWallet
|
||||||
|
}
|
||||||
|
details.ExpiresAt = &subscription.ExpiresAt
|
||||||
|
details.WalletAmount = subscription.WalletAmount
|
||||||
|
details.TopupAmount = subscription.TopupAmount
|
||||||
|
|
||||||
|
return details, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAllowedTermMonths(value int32) bool {
|
||||||
|
_, ok := allowedTermMonths[value]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockUserForUpdate(ctx context.Context, tx *gorm.DB, userID string) (*model.User, error) {
|
||||||
|
if tx.Dialector.Name() == "sqlite" {
|
||||||
|
res := tx.WithContext(ctx).Exec("UPDATE user SET id = id WHERE id = ?", userID)
|
||||||
|
if res.Error != nil {
|
||||||
|
return nil, res.Error
|
||||||
|
}
|
||||||
|
if res.RowsAffected == 0 {
|
||||||
|
return nil, gorm.ErrRecordNotFound
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var user model.User
|
||||||
|
if err := tx.WithContext(ctx).
|
||||||
|
Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||||
|
Where("id = ?", userID).
|
||||||
|
First(&user).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func maxFloat(left, right float64) float64 {
|
||||||
|
if left > right {
|
||||||
|
return left
|
||||||
|
}
|
||||||
|
return right
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatOptionalTimestamp(value *time.Time) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return value.UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshalJSON(value any) string {
|
||||||
|
encoded, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
return string(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func messageResponse(message string) *appv1.MessageResponse {
|
||||||
|
return &appv1.MessageResponse{Message: message}
|
||||||
|
}
|
||||||
570
internal/service/proto_helpers.go
Normal file
570
internal/service/proto_helpers.go
Normal file
@@ -0,0 +1,570 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
|
"stream.api/internal/database/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ensurePaidPlan(user *model.User) error {
|
||||||
|
if user == nil {
|
||||||
|
return status.Error(codes.Unauthenticated, "Unauthorized")
|
||||||
|
}
|
||||||
|
if user.PlanID == nil || strings.TrimSpace(*user.PlanID) == "" {
|
||||||
|
return status.Error(codes.PermissionDenied, adTemplateUpgradeRequiredMessage)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func playerConfigActionAllowed(user *model.User, configCount int64, action string) error {
|
||||||
|
if user == nil {
|
||||||
|
return status.Error(codes.Unauthenticated, "Unauthorized")
|
||||||
|
}
|
||||||
|
if user.PlanID != nil && strings.TrimSpace(*user.PlanID) != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case "create":
|
||||||
|
if configCount > 0 {
|
||||||
|
return status.Error(codes.FailedPrecondition, playerConfigFreePlanLimitMessage)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case "delete":
|
||||||
|
return nil
|
||||||
|
case "update", "set-default", "toggle-active":
|
||||||
|
if configCount > 1 {
|
||||||
|
return status.Error(codes.FailedPrecondition, playerConfigFreePlanReconciliationMessage)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeRole(role *string) string {
|
||||||
|
if role == nil || strings.TrimSpace(*role) == "" {
|
||||||
|
return "USER"
|
||||||
|
}
|
||||||
|
return *role
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateOAuthState() (string, error) {
|
||||||
|
buffer := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(buffer); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(buffer), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func googleOAuthStateCacheKey(state string) string {
|
||||||
|
return "google_oauth_state:" + state
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringPointerOrNil(value string) *string {
|
||||||
|
trimmed := strings.TrimSpace(value)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoVideo(item *model.Video, jobID ...string) *appv1.Video {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
statusValue := stringValue(item.Status)
|
||||||
|
if statusValue == "" {
|
||||||
|
statusValue = "ready"
|
||||||
|
}
|
||||||
|
var linkedJobID *string
|
||||||
|
if len(jobID) > 0 {
|
||||||
|
linkedJobID = stringPointerOrNil(jobID[0])
|
||||||
|
}
|
||||||
|
return &appv1.Video{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
Title: item.Title,
|
||||||
|
Description: item.Description,
|
||||||
|
Url: item.URL,
|
||||||
|
Status: strings.ToLower(statusValue),
|
||||||
|
Size: item.Size,
|
||||||
|
Duration: item.Duration,
|
||||||
|
Format: item.Format,
|
||||||
|
Thumbnail: item.Thumbnail,
|
||||||
|
ProcessingStatus: item.ProcessingStatus,
|
||||||
|
StorageType: item.StorageType,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(item.UpdatedAt.UTC()),
|
||||||
|
JobId: linkedJobID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildVideo(ctx context.Context, video *model.Video) (*appv1.Video, error) {
|
||||||
|
if video == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
jobID, err := s.loadLatestVideoJobID(ctx, video.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if jobID != nil {
|
||||||
|
return toProtoVideo(video, *jobID), nil
|
||||||
|
}
|
||||||
|
return toProtoVideo(video), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeVideoStatusValue(value string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||||
|
case "processing", "pending":
|
||||||
|
return "processing"
|
||||||
|
case "failed", "error":
|
||||||
|
return "failed"
|
||||||
|
default:
|
||||||
|
return "ready"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectStorageType(rawURL string) string {
|
||||||
|
if shouldDeleteStoredObject(rawURL) {
|
||||||
|
return "S3"
|
||||||
|
}
|
||||||
|
return "WORKER"
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldDeleteStoredObject(rawURL string) bool {
|
||||||
|
trimmed := strings.TrimSpace(rawURL)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return !strings.HasPrefix(trimmed, "/")
|
||||||
|
}
|
||||||
|
return parsed.Scheme == "" && parsed.Host == "" && !strings.HasPrefix(trimmed, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractObjectKey(rawURL string) string {
|
||||||
|
trimmed := strings.TrimSpace(rawURL)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
return strings.TrimPrefix(parsed.Path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func protoUserFromPayload(user *userPayload) *appv1.User {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.User{
|
||||||
|
Id: user.ID,
|
||||||
|
Email: user.Email,
|
||||||
|
Username: user.Username,
|
||||||
|
Avatar: user.Avatar,
|
||||||
|
Role: user.Role,
|
||||||
|
GoogleId: user.GoogleID,
|
||||||
|
StorageUsed: user.StorageUsed,
|
||||||
|
PlanId: user.PlanID,
|
||||||
|
PlanStartedAt: timeToProto(user.PlanStartedAt),
|
||||||
|
PlanExpiresAt: timeToProto(user.PlanExpiresAt),
|
||||||
|
PlanTermMonths: user.PlanTermMonths,
|
||||||
|
PlanPaymentMethod: user.PlanPaymentMethod,
|
||||||
|
PlanExpiringSoon: user.PlanExpiringSoon,
|
||||||
|
WalletBalance: user.WalletBalance,
|
||||||
|
Language: user.Language,
|
||||||
|
Locale: user.Locale,
|
||||||
|
CreatedAt: timeToProto(user.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(user.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoUser(user *userPayload) *appv1.User {
|
||||||
|
return protoUserFromPayload(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoPreferences(pref *model.UserPreference) *appv1.Preferences {
|
||||||
|
if pref == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.Preferences{
|
||||||
|
EmailNotifications: boolValue(pref.EmailNotifications),
|
||||||
|
PushNotifications: boolValue(pref.PushNotifications),
|
||||||
|
MarketingNotifications: pref.MarketingNotifications,
|
||||||
|
TelegramNotifications: pref.TelegramNotifications,
|
||||||
|
Language: model.StringValue(pref.Language),
|
||||||
|
Locale: model.StringValue(pref.Locale),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoNotification(item model.Notification) *appv1.Notification {
|
||||||
|
return &appv1.Notification{
|
||||||
|
Id: item.ID,
|
||||||
|
Type: normalizeNotificationType(item.Type),
|
||||||
|
Title: item.Title,
|
||||||
|
Message: item.Message,
|
||||||
|
Read: item.IsRead,
|
||||||
|
ActionUrl: item.ActionURL,
|
||||||
|
ActionLabel: item.ActionLabel,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoDomain(item *model.Domain) *appv1.Domain {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.Domain{
|
||||||
|
Id: item.ID,
|
||||||
|
Name: item.Name,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(item.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoAdTemplate(item *model.AdTemplate) *appv1.AdTemplate {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.AdTemplate{
|
||||||
|
Id: item.ID,
|
||||||
|
Name: item.Name,
|
||||||
|
Description: item.Description,
|
||||||
|
VastTagUrl: item.VastTagURL,
|
||||||
|
AdFormat: model.StringValue(item.AdFormat),
|
||||||
|
Duration: int64PtrToInt32Ptr(item.Duration),
|
||||||
|
IsActive: boolValue(item.IsActive),
|
||||||
|
IsDefault: item.IsDefault,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(item.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoPlayerConfig(item *model.PlayerConfig) *appv1.PlayerConfig {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.PlayerConfig{
|
||||||
|
Id: item.ID,
|
||||||
|
Name: item.Name,
|
||||||
|
Description: item.Description,
|
||||||
|
Autoplay: item.Autoplay,
|
||||||
|
Loop: item.Loop,
|
||||||
|
Muted: item.Muted,
|
||||||
|
ShowControls: boolValue(item.ShowControls),
|
||||||
|
Pip: boolValue(item.Pip),
|
||||||
|
Airplay: boolValue(item.Airplay),
|
||||||
|
Chromecast: boolValue(item.Chromecast),
|
||||||
|
IsActive: boolValue(item.IsActive),
|
||||||
|
IsDefault: item.IsDefault,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(&item.UpdatedAt),
|
||||||
|
EncrytionM3U8: boolValue(item.EncrytionM3u8),
|
||||||
|
LogoUrl: nullableTrimmedString(item.LogoURL),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoAdminPlayerConfig(item *model.PlayerConfig, ownerEmail *string) *appv1.AdminPlayerConfig {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.AdminPlayerConfig{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
Name: item.Name,
|
||||||
|
Description: item.Description,
|
||||||
|
Autoplay: item.Autoplay,
|
||||||
|
Loop: item.Loop,
|
||||||
|
Muted: item.Muted,
|
||||||
|
ShowControls: boolValue(item.ShowControls),
|
||||||
|
Pip: boolValue(item.Pip),
|
||||||
|
Airplay: boolValue(item.Airplay),
|
||||||
|
Chromecast: boolValue(item.Chromecast),
|
||||||
|
IsActive: boolValue(item.IsActive),
|
||||||
|
IsDefault: item.IsDefault,
|
||||||
|
OwnerEmail: ownerEmail,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(&item.UpdatedAt),
|
||||||
|
EncrytionM3U8: boolValue(item.EncrytionM3u8),
|
||||||
|
LogoUrl: nullableTrimmedString(item.LogoURL),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoPlan(item *model.Plan) *appv1.Plan {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.Plan{
|
||||||
|
Id: item.ID,
|
||||||
|
Name: item.Name,
|
||||||
|
Description: item.Description,
|
||||||
|
Price: item.Price,
|
||||||
|
Cycle: item.Cycle,
|
||||||
|
StorageLimit: item.StorageLimit,
|
||||||
|
UploadLimit: item.UploadLimit,
|
||||||
|
DurationLimit: item.DurationLimit,
|
||||||
|
QualityLimit: item.QualityLimit,
|
||||||
|
Features: item.Features,
|
||||||
|
IsActive: boolValue(item.IsActive),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoPayment(item *model.Payment) *appv1.Payment {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.Payment{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
PlanId: item.PlanID,
|
||||||
|
Amount: item.Amount,
|
||||||
|
Currency: normalizeCurrency(item.Currency),
|
||||||
|
Status: normalizePaymentStatus(item.Status),
|
||||||
|
Provider: strings.ToUpper(stringValue(item.Provider)),
|
||||||
|
TransactionId: item.TransactionID,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timestamppb.New(item.UpdatedAt.UTC()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoPlanSubscription(item *model.PlanSubscription) *appv1.PlanSubscription {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.PlanSubscription{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
PaymentId: item.PaymentID,
|
||||||
|
PlanId: item.PlanID,
|
||||||
|
TermMonths: item.TermMonths,
|
||||||
|
PaymentMethod: item.PaymentMethod,
|
||||||
|
WalletAmount: item.WalletAmount,
|
||||||
|
TopupAmount: item.TopupAmount,
|
||||||
|
StartedAt: timestamppb.New(item.StartedAt.UTC()),
|
||||||
|
ExpiresAt: timestamppb.New(item.ExpiresAt.UTC()),
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(item.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoWalletTransaction(item *model.WalletTransaction) *appv1.WalletTransaction {
|
||||||
|
if item == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &appv1.WalletTransaction{
|
||||||
|
Id: item.ID,
|
||||||
|
UserId: item.UserID,
|
||||||
|
Type: item.Type,
|
||||||
|
Amount: item.Amount,
|
||||||
|
Currency: normalizeCurrency(item.Currency),
|
||||||
|
Note: item.Note,
|
||||||
|
PaymentId: item.PaymentID,
|
||||||
|
PlanId: item.PlanID,
|
||||||
|
TermMonths: item.TermMonths,
|
||||||
|
CreatedAt: timeToProto(item.CreatedAt),
|
||||||
|
UpdatedAt: timeToProto(item.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func timeToProto(value *time.Time) *timestamppb.Timestamp {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return timestamppb.New(value.UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolValue(value *bool) bool {
|
||||||
|
return value != nil && *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringValue(value *string) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func int32PtrToInt64Ptr(value *int32) *int64 {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
converted := int64(*value)
|
||||||
|
return &converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func int64PtrToInt32Ptr(value *int64) *int32 {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
converted := int32(*value)
|
||||||
|
return &converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func int32Ptr(value int32) *int32 {
|
||||||
|
return &value
|
||||||
|
}
|
||||||
|
|
||||||
|
func protoStringValue(value *string) string {
|
||||||
|
if value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(*value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullableTrimmedStringPtr(value *string) *string {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(*value)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func nullableTrimmedString(value *string) *string {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(*value)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeNotificationType(value string) string {
|
||||||
|
lower := strings.ToLower(strings.TrimSpace(value))
|
||||||
|
switch {
|
||||||
|
case strings.Contains(lower, "video"):
|
||||||
|
return "video"
|
||||||
|
case strings.Contains(lower, "payment"), strings.Contains(lower, "billing"):
|
||||||
|
return "payment"
|
||||||
|
case strings.Contains(lower, "warning"):
|
||||||
|
return "warning"
|
||||||
|
case strings.Contains(lower, "error"):
|
||||||
|
return "error"
|
||||||
|
case strings.Contains(lower, "success"):
|
||||||
|
return "success"
|
||||||
|
case strings.Contains(lower, "system"):
|
||||||
|
return "system"
|
||||||
|
default:
|
||||||
|
return "info"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeDomain(value string) string {
|
||||||
|
normalized := strings.TrimSpace(strings.ToLower(value))
|
||||||
|
normalized = strings.TrimPrefix(normalized, "https://")
|
||||||
|
normalized = strings.TrimPrefix(normalized, "http://")
|
||||||
|
normalized = strings.TrimPrefix(normalized, "www.")
|
||||||
|
normalized = strings.TrimSuffix(normalized, "/")
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAdFormat(value string) string {
|
||||||
|
switch strings.TrimSpace(strings.ToLower(value)) {
|
||||||
|
case "mid-roll", "post-roll":
|
||||||
|
return strings.TrimSpace(strings.ToLower(value))
|
||||||
|
default:
|
||||||
|
return "pre-roll"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func adTemplateIsActive(value *bool) bool {
|
||||||
|
return value == nil || *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func playerConfigIsActive(value *bool) bool {
|
||||||
|
return value == nil || *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func unsetDefaultTemplates(tx *gorm.DB, userID, excludeID string) error {
|
||||||
|
query := tx.Model(&model.AdTemplate{}).Where("user_id = ?", userID)
|
||||||
|
if excludeID != "" {
|
||||||
|
query = query.Where("id <> ?", excludeID)
|
||||||
|
}
|
||||||
|
return query.Update("is_default", false).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func unsetDefaultPlayerConfigs(tx *gorm.DB, userID, excludeID string) error {
|
||||||
|
query := tx.Model(&model.PlayerConfig{}).Where("user_id = ?", userID)
|
||||||
|
if excludeID != "" {
|
||||||
|
query = query.Where("id <> ?", excludeID)
|
||||||
|
}
|
||||||
|
return query.Update("is_default", false).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePaymentStatus(status *string) string {
|
||||||
|
value := strings.ToLower(strings.TrimSpace(stringValue(status)))
|
||||||
|
switch value {
|
||||||
|
case "success", "succeeded", "paid":
|
||||||
|
return "success"
|
||||||
|
case "failed", "error", "canceled", "cancelled":
|
||||||
|
return "failed"
|
||||||
|
case "pending", "processing":
|
||||||
|
return "pending"
|
||||||
|
default:
|
||||||
|
if value == "" {
|
||||||
|
return "success"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeCurrency(currency *string) string {
|
||||||
|
value := strings.ToUpper(strings.TrimSpace(stringValue(currency)))
|
||||||
|
if value == "" {
|
||||||
|
return "USD"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePaymentMethod(value string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||||
|
case paymentMethodWallet:
|
||||||
|
return paymentMethodWallet
|
||||||
|
case paymentMethodTopup:
|
||||||
|
return paymentMethodTopup
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeOptionalPaymentMethod(value *string) *string {
|
||||||
|
normalized := normalizePaymentMethod(stringValue(value))
|
||||||
|
if normalized == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildInvoiceID(id string) string {
|
||||||
|
trimmed := strings.ReplaceAll(strings.TrimSpace(id), "-", "")
|
||||||
|
if len(trimmed) > 12 {
|
||||||
|
trimmed = trimmed[:12]
|
||||||
|
}
|
||||||
|
return "INV-" + strings.ToUpper(trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTransactionID(prefix string) string {
|
||||||
|
return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildInvoiceFilename(id string) string {
|
||||||
|
return fmt.Sprintf("invoice-%s.txt", id)
|
||||||
|
}
|
||||||
240
internal/service/referral_helpers.go
Normal file
240
internal/service/referral_helpers.go
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"stream.api/internal/database/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func referralUserEligible(user *model.User) bool {
|
||||||
|
if user == nil || user.ReferralEligible == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return *user.ReferralEligible
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveReferralRewardBps(value *int32) int32 {
|
||||||
|
if value == nil {
|
||||||
|
return defaultReferralRewardBps
|
||||||
|
}
|
||||||
|
if *value < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if *value > 10000 {
|
||||||
|
return 10000
|
||||||
|
}
|
||||||
|
return *value
|
||||||
|
}
|
||||||
|
|
||||||
|
func referralRewardBpsToPercent(value int32) float64 {
|
||||||
|
return float64(value) / 100
|
||||||
|
}
|
||||||
|
|
||||||
|
func referralRewardProcessed(user *model.User) bool {
|
||||||
|
if user == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if user.ReferralRewardGrantedAt != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if user.ReferralRewardPaymentID != nil && strings.TrimSpace(*user.ReferralRewardPaymentID) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameTrimmedStringFold(left *string, right string) bool {
|
||||||
|
if left == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(*left), strings.TrimSpace(right))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) buildReferralShareLink(username *string) *string {
|
||||||
|
trimmed := strings.TrimSpace(stringValue(username))
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
path := "/ref/" + url.PathEscape(trimmed)
|
||||||
|
base := strings.TrimRight(strings.TrimSpace(s.frontendBaseURL), "/")
|
||||||
|
if base == "" {
|
||||||
|
return &path
|
||||||
|
}
|
||||||
|
link := base + path
|
||||||
|
return &link
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadReferralUsersByUsername(ctx context.Context, username string) ([]model.User, error) {
|
||||||
|
trimmed := strings.TrimSpace(username)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var users []model.User
|
||||||
|
if err := s.db.WithContext(ctx).
|
||||||
|
Where("LOWER(username) = LOWER(?)", trimmed).
|
||||||
|
Order("created_at ASC, id ASC").
|
||||||
|
Limit(2).
|
||||||
|
Find(&users).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) resolveReferralUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||||
|
users, err := s.loadReferralUsersByUsername(ctx, username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(users) != 1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &users[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) loadReferralUserByUsernameStrict(ctx context.Context, username string) (*model.User, error) {
|
||||||
|
trimmed := strings.TrimSpace(username)
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Referral username is required")
|
||||||
|
}
|
||||||
|
users, err := s.loadReferralUsersByUsername(ctx, trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Error(codes.Internal, "Failed to resolve referral user")
|
||||||
|
}
|
||||||
|
if len(users) == 0 {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Referral user not found")
|
||||||
|
}
|
||||||
|
if len(users) > 1 {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "Referral username is ambiguous")
|
||||||
|
}
|
||||||
|
return &users[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) resolveSignupReferrerID(ctx context.Context, refUsername string, newUsername string) (*string, error) {
|
||||||
|
trimmedRefUsername := strings.TrimSpace(refUsername)
|
||||||
|
if trimmedRefUsername == "" || strings.EqualFold(trimmedRefUsername, strings.TrimSpace(newUsername)) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
referrer, err := s.resolveReferralUserByUsername(ctx, trimmedRefUsername)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if referrer == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &referrer.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildReferralRewardNotification(userID string, rewardAmount float64, referee *model.User, paymentRecord *model.Payment) *model.Notification {
|
||||||
|
refereeLabel := strings.TrimSpace(referee.Email)
|
||||||
|
if username := strings.TrimSpace(stringValue(referee.Username)); username != "" {
|
||||||
|
refereeLabel = "@" + username
|
||||||
|
}
|
||||||
|
return &model.Notification{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: userID,
|
||||||
|
Type: "billing.referral_reward",
|
||||||
|
Title: "Referral reward granted",
|
||||||
|
Message: fmt.Sprintf("You received %.2f USD from %s's first subscription.", rewardAmount, refereeLabel),
|
||||||
|
Metadata: model.StringPtr(mustMarshalJSON(map[string]any{
|
||||||
|
"payment_id": paymentRecord.ID,
|
||||||
|
"referee_id": referee.ID,
|
||||||
|
"amount": rewardAmount,
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *appServices) maybeGrantReferralReward(ctx context.Context, tx *gorm.DB, input paymentExecutionInput, paymentRecord *model.Payment, subscription *model.PlanSubscription) (*referralRewardResult, error) {
|
||||||
|
if paymentRecord == nil || subscription == nil || input.Plan == nil {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
if subscription.PaymentMethod != paymentMethodWallet && subscription.PaymentMethod != paymentMethodTopup {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
referee, err := lockUserForUpdate(ctx, tx, input.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if referee.ReferredByUserID == nil || strings.TrimSpace(*referee.ReferredByUserID) == "" {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
if referralRewardProcessed(referee) {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var subscriptionCount int64
|
||||||
|
if err := tx.WithContext(ctx).
|
||||||
|
Model(&model.PlanSubscription{}).
|
||||||
|
Where("user_id = ?", referee.ID).
|
||||||
|
Count(&subscriptionCount).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if subscriptionCount != 1 {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
referrer, err := lockUserForUpdate(ctx, tx, strings.TrimSpace(*referee.ReferredByUserID))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if referrer.ID == referee.ID || !referralUserEligible(referrer) {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bps := effectiveReferralRewardBps(referrer.ReferralRewardBps)
|
||||||
|
if bps <= 0 {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
baseAmount := input.Plan.Price * float64(input.TermMonths)
|
||||||
|
if baseAmount <= 0 {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
rewardAmount := baseAmount * float64(bps) / 10000
|
||||||
|
if rewardAmount <= 0 {
|
||||||
|
return &referralRewardResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
currency := normalizeCurrency(paymentRecord.Currency)
|
||||||
|
rewardTransaction := &model.WalletTransaction{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: referrer.ID,
|
||||||
|
Type: walletTransactionTypeReferralReward,
|
||||||
|
Amount: rewardAmount,
|
||||||
|
Currency: model.StringPtr(currency),
|
||||||
|
Note: model.StringPtr(fmt.Sprintf("Referral reward for %s first subscription", referee.Email)),
|
||||||
|
PaymentID: &paymentRecord.ID,
|
||||||
|
PlanID: &input.Plan.ID,
|
||||||
|
}
|
||||||
|
if err := tx.Create(rewardTransaction).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := tx.Create(buildReferralRewardNotification(referrer.ID, rewardAmount, referee, paymentRecord)).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
updates := map[string]any{
|
||||||
|
"referral_reward_granted_at": now,
|
||||||
|
"referral_reward_payment_id": paymentRecord.ID,
|
||||||
|
"referral_reward_amount": rewardAmount,
|
||||||
|
}
|
||||||
|
if err := tx.WithContext(ctx).Model(&model.User{}).Where("id = ?", referee.ID).Updates(updates).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
referee.ReferralRewardGrantedAt = &now
|
||||||
|
referee.ReferralRewardPaymentID = &paymentRecord.ID
|
||||||
|
referee.ReferralRewardAmount = &rewardAmount
|
||||||
|
return &referralRewardResult{Granted: true, Amount: rewardAmount}, nil
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
appv1 "stream.api/internal/api/proto/app/v1"
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
"stream.api/internal/video"
|
"stream.api/internal/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *appServices) ListAdminJobs(ctx context.Context, req *appv1.ListAdminJobsRequest) (*appv1.ListAdminJobsResponse, error) {
|
func (s *appServices) ListAdminJobs(ctx context.Context, req *appv1.ListAdminJobsRequest) (*appv1.ListAdminJobsResponse, error) {
|
||||||
@@ -28,7 +28,7 @@ func (s *appServices) ListAdminJobs(ctx context.Context, req *appv1.ListAdminJob
|
|||||||
useCursorPagination := req.Cursor != nil || pageSize > 0
|
useCursorPagination := req.Cursor != nil || pageSize > 0
|
||||||
|
|
||||||
var (
|
var (
|
||||||
result *video.PaginatedJobs
|
result *dto.PaginatedJobs
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if useCursorPagination {
|
if useCursorPagination {
|
||||||
@@ -39,7 +39,7 @@ func (s *appServices) ListAdminJobs(ctx context.Context, req *appv1.ListAdminJob
|
|||||||
result, err = s.videoService.ListJobs(ctx, offset, limit)
|
result, err = s.videoService.ListJobs(ctx, offset, limit)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, video.ErrInvalidJobCursor) {
|
if errors.Is(err, ErrInvalidJobCursor) {
|
||||||
return nil, status.Error(codes.InvalidArgument, "Invalid job cursor")
|
return nil, status.Error(codes.InvalidArgument, "Invalid job cursor")
|
||||||
}
|
}
|
||||||
return nil, status.Error(codes.Internal, "Failed to list jobs")
|
return nil, status.Error(codes.Internal, "Failed to list jobs")
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
appv1 "stream.api/internal/api/proto/app/v1"
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/video"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *appServices) GetAdminDashboard(ctx context.Context, _ *appv1.GetAdminDashboardRequest) (*appv1.GetAdminDashboardResponse, error) {
|
func (s *appServices) GetAdminDashboard(ctx context.Context, _ *appv1.GetAdminDashboardRequest) (*appv1.GetAdminDashboardResponse, error) {
|
||||||
@@ -491,7 +490,7 @@ func (s *appServices) CreateAdminVideo(ctx context.Context, req *appv1.CreateAdm
|
|||||||
return nil, status.Error(codes.InvalidArgument, "Size must be greater than or equal to 0")
|
return nil, status.Error(codes.InvalidArgument, "Size must be greater than or equal to 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := s.videoService.CreateVideo(ctx, video.CreateVideoInput{
|
created, err := s.videoService.CreateVideo(ctx, CreateVideoInput{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
Title: title,
|
Title: title,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
@@ -503,11 +502,11 @@ func (s *appServices) CreateAdminVideo(ctx context.Context, req *appv1.CreateAdm
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, video.ErrUserNotFound):
|
case errors.Is(err, ErrUserNotFound):
|
||||||
return nil, status.Error(codes.InvalidArgument, "User not found")
|
return nil, status.Error(codes.InvalidArgument, "User not found")
|
||||||
case errors.Is(err, video.ErrAdTemplateNotFound):
|
case errors.Is(err, ErrAdTemplateNotFound):
|
||||||
return nil, status.Error(codes.InvalidArgument, "Ad template not found")
|
return nil, status.Error(codes.InvalidArgument, "Ad template not found")
|
||||||
case errors.Is(err, video.ErrJobServiceUnavailable):
|
case errors.Is(err, ErrJobServiceUnavailable):
|
||||||
return nil, status.Error(codes.Unavailable, "Job service is unavailable")
|
return nil, status.Error(codes.Unavailable, "Job service is unavailable")
|
||||||
default:
|
default:
|
||||||
return nil, status.Error(codes.Internal, "Failed to create video")
|
return nil, status.Error(codes.Internal, "Failed to create video")
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"stream.api/internal/config"
|
"stream.api/internal/config"
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/middleware"
|
"stream.api/internal/middleware"
|
||||||
"stream.api/internal/video"
|
|
||||||
"stream.api/pkg/logger"
|
"stream.api/pkg/logger"
|
||||||
"stream.api/pkg/storage"
|
"stream.api/pkg/storage"
|
||||||
)
|
)
|
||||||
@@ -74,8 +73,8 @@ type appServices struct {
|
|||||||
authenticator *middleware.Authenticator
|
authenticator *middleware.Authenticator
|
||||||
cache *redis.RedisAdapter
|
cache *redis.RedisAdapter
|
||||||
storageProvider storage.Provider
|
storageProvider storage.Provider
|
||||||
videoService *video.Service
|
videoService *Service
|
||||||
agentRuntime video.AgentRuntime
|
agentRuntime AgentRuntime
|
||||||
googleOauth *oauth2.Config
|
googleOauth *oauth2.Config
|
||||||
googleStateTTL time.Duration
|
googleStateTTL time.Duration
|
||||||
googleUserInfoURL string
|
googleUserInfoURL string
|
||||||
@@ -117,7 +116,7 @@ type apiErrorBody struct {
|
|||||||
Data any `json:"data,omitempty"`
|
Data any `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServices(c *redis.RedisAdapter, db *gorm.DB, l logger.Logger, cfg *config.Config, videoService *video.Service, agentRuntime video.AgentRuntime) *Services {
|
func NewServices(c *redis.RedisAdapter, db *gorm.DB, l logger.Logger, cfg *config.Config, videoService *Service, agentRuntime AgentRuntime) *Services {
|
||||||
var storageProvider storage.Provider
|
var storageProvider storage.Provider
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
provider, err := storage.NewS3Provider(cfg)
|
provider, err := storage.NewS3Provider(cfg)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package services
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
package services
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/database/query"
|
"stream.api/internal/database/query"
|
||||||
"stream.api/internal/video/runtime/domain"
|
"stream.api/internal/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JobQueue interface {
|
type JobQueue interface {
|
||||||
@@ -26,8 +26,8 @@ type LogPubSub interface {
|
|||||||
PublishResource(ctx context.Context, agentID string, data []byte) error
|
PublishResource(ctx context.Context, agentID string, data []byte) error
|
||||||
PublishCancel(ctx context.Context, agentID string, jobID string) error
|
PublishCancel(ctx context.Context, agentID string, jobID string) error
|
||||||
PublishJobUpdate(ctx context.Context, jobID string, status string, videoID string) error
|
PublishJobUpdate(ctx context.Context, jobID string, status string, videoID string) error
|
||||||
Subscribe(ctx context.Context, jobID string) (<-chan domain.LogEntry, error)
|
Subscribe(ctx context.Context, jobID string) (<-chan dto.LogEntry, error)
|
||||||
SubscribeResources(ctx context.Context) (<-chan domain.SystemResource, error)
|
SubscribeResources(ctx context.Context) (<-chan dto.SystemResource, error)
|
||||||
SubscribeCancel(ctx context.Context, agentID string) (<-chan string, error)
|
SubscribeCancel(ctx context.Context, agentID string) (<-chan string, error)
|
||||||
SubscribeJobUpdates(ctx context.Context) (<-chan string, error)
|
SubscribeJobUpdates(ctx context.Context) (<-chan string, error)
|
||||||
}
|
}
|
||||||
@@ -43,50 +43,17 @@ func NewJobService(queue JobQueue, pubsub LogPubSub) *JobService {
|
|||||||
|
|
||||||
var ErrInvalidJobCursor = errors.New("invalid job cursor")
|
var ErrInvalidJobCursor = errors.New("invalid job cursor")
|
||||||
|
|
||||||
const (
|
|
||||||
defaultJobPageSize = 20
|
|
||||||
maxJobPageSize = 100
|
|
||||||
jobCursorVersion = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
type PaginatedJobs struct {
|
|
||||||
Jobs []*model.Job `json:"jobs"`
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
Offset int `json:"offset"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
HasMore bool `json:"has_more"`
|
|
||||||
NextCursor string `json:"next_cursor,omitempty"`
|
|
||||||
PageSize int `json:"page_size"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type jobListCursor struct {
|
|
||||||
Version int `json:"v"`
|
|
||||||
CreatedAtUnixNano int64 `json:"created_at_unix_nano"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
AgentID string `json:"agent_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type jobConfigEnvelope struct {
|
|
||||||
Image string `json:"image,omitempty"`
|
|
||||||
Commands []string `json:"commands,omitempty"`
|
|
||||||
Environment map[string]string `json:"environment,omitempty"`
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
UserID string `json:"user_id,omitempty"`
|
|
||||||
VideoID string `json:"video_id,omitempty"`
|
|
||||||
TimeLimit int64 `json:"time_limit,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func strPtr(v string) *string { return &v }
|
func strPtr(v string) *string { return &v }
|
||||||
func int64Ptr(v int64) *int64 { return &v }
|
func int64Ptr(v int64) *int64 { return &v }
|
||||||
func boolPtr(v bool) *bool { return &v }
|
func boolPtr(v bool) *bool { return &v }
|
||||||
func float64Ptr(v float64) *float64 { return &v }
|
func float64Ptr(v float64) *float64 { return &v }
|
||||||
func timePtr(v time.Time) *time.Time { return &v }
|
func timePtr(v time.Time) *time.Time { return &v }
|
||||||
|
|
||||||
func parseJobConfig(raw *string) jobConfigEnvelope {
|
func parseJobConfig(raw *string) dto.JobConfigEnvelope {
|
||||||
if raw == nil || strings.TrimSpace(*raw) == "" {
|
if raw == nil || strings.TrimSpace(*raw) == "" {
|
||||||
return jobConfigEnvelope{}
|
return dto.JobConfigEnvelope{}
|
||||||
}
|
}
|
||||||
var cfg jobConfigEnvelope
|
var cfg dto.JobConfigEnvelope
|
||||||
_ = json.Unmarshal([]byte(*raw), &cfg)
|
_ = json.Unmarshal([]byte(*raw), &cfg)
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
@@ -111,15 +78,15 @@ func encodeJobConfig(raw []byte, name, userID, videoID string, timeLimit int64)
|
|||||||
|
|
||||||
func normalizeJobPageSize(pageSize int) int {
|
func normalizeJobPageSize(pageSize int) int {
|
||||||
if pageSize <= 0 {
|
if pageSize <= 0 {
|
||||||
return defaultJobPageSize
|
return dto.DefaultJobPageSize
|
||||||
}
|
}
|
||||||
if pageSize > maxJobPageSize {
|
if pageSize > dto.MaxJobPageSize {
|
||||||
return maxJobPageSize
|
return dto.MaxJobPageSize
|
||||||
}
|
}
|
||||||
return pageSize
|
return pageSize
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodeJobListCursor(cursor jobListCursor) (string, error) {
|
func encodeJobListCursor(cursor dto.JobListCursor) (string, error) {
|
||||||
payload, err := json.Marshal(cursor)
|
payload, err := json.Marshal(cursor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -127,7 +94,7 @@ func encodeJobListCursor(cursor jobListCursor) (string, error) {
|
|||||||
return base64.RawURLEncoding.EncodeToString(payload), nil
|
return base64.RawURLEncoding.EncodeToString(payload), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeJobListCursor(raw string) (*jobListCursor, error) {
|
func decodeJobListCursor(raw string) (*dto.JobListCursor, error) {
|
||||||
raw = strings.TrimSpace(raw)
|
raw = strings.TrimSpace(raw)
|
||||||
if raw == "" {
|
if raw == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -136,11 +103,11 @@ func decodeJobListCursor(raw string) (*jobListCursor, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidJobCursor
|
return nil, ErrInvalidJobCursor
|
||||||
}
|
}
|
||||||
var cursor jobListCursor
|
var cursor dto.JobListCursor
|
||||||
if err := json.Unmarshal(payload, &cursor); err != nil {
|
if err := json.Unmarshal(payload, &cursor); err != nil {
|
||||||
return nil, ErrInvalidJobCursor
|
return nil, ErrInvalidJobCursor
|
||||||
}
|
}
|
||||||
if cursor.Version != jobCursorVersion || cursor.CreatedAtUnixNano <= 0 || strings.TrimSpace(cursor.ID) == "" {
|
if cursor.Version != dto.JobCursorVersion || cursor.CreatedAtUnixNano <= 0 || strings.TrimSpace(cursor.ID) == "" {
|
||||||
return nil, ErrInvalidJobCursor
|
return nil, ErrInvalidJobCursor
|
||||||
}
|
}
|
||||||
cursor.ID = strings.TrimSpace(cursor.ID)
|
cursor.ID = strings.TrimSpace(cursor.ID)
|
||||||
@@ -152,15 +119,15 @@ func buildJobListCursor(job *model.Job, agentID string) (string, error) {
|
|||||||
if job == nil || job.CreatedAt == nil {
|
if job == nil || job.CreatedAt == nil {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return encodeJobListCursor(jobListCursor{
|
return encodeJobListCursor(dto.JobListCursor{
|
||||||
Version: jobCursorVersion,
|
Version: dto.JobCursorVersion,
|
||||||
CreatedAtUnixNano: job.CreatedAt.UTC().UnixNano(),
|
CreatedAtUnixNano: job.CreatedAt.UTC().UnixNano(),
|
||||||
ID: strings.TrimSpace(job.ID),
|
ID: strings.TrimSpace(job.ID),
|
||||||
AgentID: strings.TrimSpace(agentID),
|
AgentID: strings.TrimSpace(agentID),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func listJobsByOffset(ctx context.Context, agentID string, offset, limit int) (*PaginatedJobs, error) {
|
func listJobsByOffset(ctx context.Context, agentID string, offset, limit int) (*dto.PaginatedJobs, error) {
|
||||||
if offset < 0 {
|
if offset < 0 {
|
||||||
offset = 0
|
offset = 0
|
||||||
}
|
}
|
||||||
@@ -169,7 +136,7 @@ func listJobsByOffset(ctx context.Context, agentID string, offset, limit int) (*
|
|||||||
if agentID != "" {
|
if agentID != "" {
|
||||||
agentNumeric, err := strconv.ParseInt(agentID, 10, 64)
|
agentNumeric, err := strconv.ParseInt(agentID, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &PaginatedJobs{Jobs: []*model.Job{}, Total: 0, Offset: offset, Limit: limit, PageSize: limit, HasMore: false}, nil
|
return &dto.PaginatedJobs{Jobs: []*model.Job{}, Total: 0, Offset: offset, Limit: limit, PageSize: limit, HasMore: false}, nil
|
||||||
}
|
}
|
||||||
q = q.Where(query.Job.AgentID.Eq(agentNumeric))
|
q = q.Where(query.Job.AgentID.Eq(agentNumeric))
|
||||||
}
|
}
|
||||||
@@ -181,11 +148,11 @@ func listJobsByOffset(ctx context.Context, agentID string, offset, limit int) (*
|
|||||||
for _, job := range jobs {
|
for _, job := range jobs {
|
||||||
items = append(items, job)
|
items = append(items, job)
|
||||||
}
|
}
|
||||||
return &PaginatedJobs{Jobs: items, Total: total, Offset: offset, Limit: limit, PageSize: limit, HasMore: offset+len(items) < int(total)}, nil
|
return &dto.PaginatedJobs{Jobs: items, Total: total, Offset: offset, Limit: limit, PageSize: limit, HasMore: offset+len(items) < int(total)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JobService) CreateJob(ctx context.Context, userID string, videoID string, name string, config []byte, priority int, timeLimit int64) (*model.Job, error) {
|
func (s *JobService) CreateJob(ctx context.Context, userID string, videoID string, name string, config []byte, priority int, timeLimit int64) (*model.Job, error) {
|
||||||
status := string(domain.JobStatusPending)
|
status := string(dto.JobStatusPending)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
job := &model.Job{
|
job := &model.Job{
|
||||||
ID: fmt.Sprintf("job-%d", now.UnixNano()),
|
ID: fmt.Sprintf("job-%d", now.UnixNano()),
|
||||||
@@ -201,10 +168,10 @@ func (s *JobService) CreateJob(ctx context.Context, userID string, videoID strin
|
|||||||
if err := query.Job.WithContext(ctx).Create(job); err != nil {
|
if err := query.Job.WithContext(ctx).Create(job); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := syncVideoStatus(ctx, videoID, domain.JobStatusPending); err != nil {
|
if err := syncVideoStatus(ctx, videoID, dto.JobStatusPending); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// domainJob := toDomainJob(job)
|
// dtoJob := todtoJob(job)
|
||||||
if err := s.queue.Enqueue(ctx, job); err != nil {
|
if err := s.queue.Enqueue(ctx, job); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -212,15 +179,15 @@ func (s *JobService) CreateJob(ctx context.Context, userID string, videoID strin
|
|||||||
return job, nil
|
return job, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JobService) ListJobs(ctx context.Context, offset, limit int) (*PaginatedJobs, error) {
|
func (s *JobService) ListJobs(ctx context.Context, offset, limit int) (*dto.PaginatedJobs, error) {
|
||||||
return listJobsByOffset(ctx, "", offset, limit)
|
return listJobsByOffset(ctx, "", offset, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JobService) ListJobsByAgent(ctx context.Context, agentID string, offset, limit int) (*PaginatedJobs, error) {
|
func (s *JobService) ListJobsByAgent(ctx context.Context, agentID string, offset, limit int) (*dto.PaginatedJobs, error) {
|
||||||
return listJobsByOffset(ctx, strings.TrimSpace(agentID), offset, limit)
|
return listJobsByOffset(ctx, strings.TrimSpace(agentID), offset, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JobService) ListJobsByCursor(ctx context.Context, agentID string, cursor string, pageSize int) (*PaginatedJobs, error) {
|
func (s *JobService) ListJobsByCursor(ctx context.Context, agentID string, cursor string, pageSize int) (*dto.PaginatedJobs, error) {
|
||||||
agentID = strings.TrimSpace(agentID)
|
agentID = strings.TrimSpace(agentID)
|
||||||
pageSize = normalizeJobPageSize(pageSize)
|
pageSize = normalizeJobPageSize(pageSize)
|
||||||
|
|
||||||
@@ -236,7 +203,7 @@ func (s *JobService) ListJobsByCursor(ctx context.Context, agentID string, curso
|
|||||||
if agentID != "" {
|
if agentID != "" {
|
||||||
agentNumeric, err := strconv.ParseInt(agentID, 10, 64)
|
agentNumeric, err := strconv.ParseInt(agentID, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &PaginatedJobs{Jobs: []*model.Job{}, Total: 0, Limit: pageSize, PageSize: pageSize, HasMore: false}, nil
|
return &dto.PaginatedJobs{Jobs: []*model.Job{}, Total: 0, Limit: pageSize, PageSize: pageSize, HasMore: false}, nil
|
||||||
}
|
}
|
||||||
q = q.Where(query.Job.AgentID.Eq(agentNumeric))
|
q = q.Where(query.Job.AgentID.Eq(agentNumeric))
|
||||||
}
|
}
|
||||||
@@ -273,7 +240,7 @@ func (s *JobService) ListJobsByCursor(ctx context.Context, agentID string, curso
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PaginatedJobs{
|
return &dto.PaginatedJobs{
|
||||||
Jobs: items,
|
Jobs: items,
|
||||||
Total: 0,
|
Total: 0,
|
||||||
Limit: pageSize,
|
Limit: pageSize,
|
||||||
@@ -294,10 +261,10 @@ func (s *JobService) GetJob(ctx context.Context, id string) (*model.Job, error)
|
|||||||
func (s *JobService) GetNextJob(ctx context.Context) (*model.Job, error) {
|
func (s *JobService) GetNextJob(ctx context.Context) (*model.Job, error) {
|
||||||
return s.queue.Dequeue(ctx)
|
return s.queue.Dequeue(ctx)
|
||||||
}
|
}
|
||||||
func (s *JobService) SubscribeSystemResources(ctx context.Context) (<-chan domain.SystemResource, error) {
|
func (s *JobService) SubscribeSystemResources(ctx context.Context) (<-chan dto.SystemResource, error) {
|
||||||
return s.pubsub.SubscribeResources(ctx)
|
return s.pubsub.SubscribeResources(ctx)
|
||||||
}
|
}
|
||||||
func (s *JobService) SubscribeJobLogs(ctx context.Context, jobID string) (<-chan domain.LogEntry, error) {
|
func (s *JobService) SubscribeJobLogs(ctx context.Context, jobID string) (<-chan dto.LogEntry, error) {
|
||||||
return s.pubsub.Subscribe(ctx, jobID)
|
return s.pubsub.Subscribe(ctx, jobID)
|
||||||
}
|
}
|
||||||
func (s *JobService) SubscribeCancel(ctx context.Context, agentID string) (<-chan string, error) {
|
func (s *JobService) SubscribeCancel(ctx context.Context, agentID string) (<-chan string, error) {
|
||||||
@@ -307,7 +274,7 @@ func (s *JobService) SubscribeJobUpdates(ctx context.Context) (<-chan string, er
|
|||||||
return s.pubsub.SubscribeJobUpdates(ctx)
|
return s.pubsub.SubscribeJobUpdates(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *JobService) UpdateJobStatus(ctx context.Context, jobID string, status domain.JobStatus) error {
|
func (s *JobService) UpdateJobStatus(ctx context.Context, jobID string, status dto.JobStatus) error {
|
||||||
job, err := query.Job.WithContext(ctx).Where(query.Job.ID.Eq(jobID)).First()
|
job, err := query.Job.WithContext(ctx).Where(query.Job.ID.Eq(jobID)).First()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -335,7 +302,7 @@ func (s *JobService) AssignJob(ctx context.Context, jobID string, agentID string
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
status := string(domain.JobStatusRunning)
|
status := string(dto.JobStatusRunning)
|
||||||
job.AgentID = &agentNumeric
|
job.AgentID = &agentNumeric
|
||||||
job.Status = &status
|
job.Status = &status
|
||||||
job.UpdatedAt = &now
|
job.UpdatedAt = &now
|
||||||
@@ -343,7 +310,7 @@ func (s *JobService) AssignJob(ctx context.Context, jobID string, agentID string
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cfg := parseJobConfig(job.Config)
|
cfg := parseJobConfig(job.Config)
|
||||||
if err := syncVideoStatus(ctx, cfg.VideoID, domain.JobStatusRunning); err != nil {
|
if err := syncVideoStatus(ctx, cfg.VideoID, dto.JobStatusRunning); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.pubsub.PublishJobUpdate(ctx, jobID, status, cfg.VideoID)
|
return s.pubsub.PublishJobUpdate(ctx, jobID, status, cfg.VideoID)
|
||||||
@@ -358,11 +325,11 @@ func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
|
|||||||
if job.Status != nil {
|
if job.Status != nil {
|
||||||
currentStatus = *job.Status
|
currentStatus = *job.Status
|
||||||
}
|
}
|
||||||
if currentStatus != string(domain.JobStatusPending) && currentStatus != string(domain.JobStatusRunning) {
|
if currentStatus != string(dto.JobStatusPending) && currentStatus != string(dto.JobStatusRunning) {
|
||||||
return fmt.Errorf("cannot cancel job with status %s", currentStatus)
|
return fmt.Errorf("cannot cancel job with status %s", currentStatus)
|
||||||
}
|
}
|
||||||
cancelled := true
|
cancelled := true
|
||||||
status := string(domain.JobStatusCancelled)
|
status := string(dto.JobStatusCancelled)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
job.Cancelled = &cancelled
|
job.Cancelled = &cancelled
|
||||||
job.Status = &status
|
job.Status = &status
|
||||||
@@ -371,7 +338,7 @@ func (s *JobService) CancelJob(ctx context.Context, jobID string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cfg := parseJobConfig(job.Config)
|
cfg := parseJobConfig(job.Config)
|
||||||
if err := syncVideoStatus(ctx, cfg.VideoID, domain.JobStatusCancelled); err != nil {
|
if err := syncVideoStatus(ctx, cfg.VideoID, dto.JobStatusCancelled); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_ = s.pubsub.PublishJobUpdate(ctx, jobID, status, cfg.VideoID)
|
_ = s.pubsub.PublishJobUpdate(ctx, jobID, status, cfg.VideoID)
|
||||||
@@ -390,7 +357,7 @@ func (s *JobService) RetryJob(ctx context.Context, jobID string) (*model.Job, er
|
|||||||
if job.Status != nil {
|
if job.Status != nil {
|
||||||
currentStatus = *job.Status
|
currentStatus = *job.Status
|
||||||
}
|
}
|
||||||
if currentStatus != string(domain.JobStatusFailure) && currentStatus != string(domain.JobStatusCancelled) {
|
if currentStatus != string(dto.JobStatusFailure) && currentStatus != string(dto.JobStatusCancelled) {
|
||||||
return nil, fmt.Errorf("cannot retry job with status %s", currentStatus)
|
return nil, fmt.Errorf("cannot retry job with status %s", currentStatus)
|
||||||
}
|
}
|
||||||
currentRetry := int64(0)
|
currentRetry := int64(0)
|
||||||
@@ -404,7 +371,7 @@ func (s *JobService) RetryJob(ctx context.Context, jobID string) (*model.Job, er
|
|||||||
if currentRetry >= maxRetries {
|
if currentRetry >= maxRetries {
|
||||||
return nil, fmt.Errorf("max retries (%d) exceeded", maxRetries)
|
return nil, fmt.Errorf("max retries (%d) exceeded", maxRetries)
|
||||||
}
|
}
|
||||||
pending := string(domain.JobStatusPending)
|
pending := string(dto.JobStatusPending)
|
||||||
cancelled := false
|
cancelled := false
|
||||||
progress := 0.0
|
progress := 0.0
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -418,10 +385,10 @@ func (s *JobService) RetryJob(ctx context.Context, jobID string) (*model.Job, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
cfg := parseJobConfig(job.Config)
|
cfg := parseJobConfig(job.Config)
|
||||||
if err := syncVideoStatus(ctx, cfg.VideoID, domain.JobStatusPending); err != nil {
|
if err := syncVideoStatus(ctx, cfg.VideoID, dto.JobStatusPending); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// domainJob := toDomainJob(job)
|
// dtoJob := todtoJob(job)
|
||||||
if err := s.queue.Enqueue(ctx, job); err != nil {
|
if err := s.queue.Enqueue(ctx, job); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -482,7 +449,7 @@ func (s *JobService) ProcessLog(ctx context.Context, jobID string, logData []byt
|
|||||||
return s.pubsub.Publish(ctx, jobID, line, progress)
|
return s.pubsub.Publish(ctx, jobID, line, progress)
|
||||||
}
|
}
|
||||||
|
|
||||||
func syncVideoStatus(ctx context.Context, videoID string, status domain.JobStatus) error {
|
func syncVideoStatus(ctx context.Context, videoID string, status dto.JobStatus) error {
|
||||||
videoID = strings.TrimSpace(videoID)
|
videoID = strings.TrimSpace(videoID)
|
||||||
if videoID == "" {
|
if videoID == "" {
|
||||||
return nil
|
return nil
|
||||||
@@ -491,10 +458,10 @@ func syncVideoStatus(ctx context.Context, videoID string, status domain.JobStatu
|
|||||||
statusValue := "processing"
|
statusValue := "processing"
|
||||||
processingStatus := "PROCESSING"
|
processingStatus := "PROCESSING"
|
||||||
switch status {
|
switch status {
|
||||||
case domain.JobStatusSuccess:
|
case dto.JobStatusSuccess:
|
||||||
statusValue = "ready"
|
statusValue = "ready"
|
||||||
processingStatus = "READY"
|
processingStatus = "READY"
|
||||||
case domain.JobStatusFailure, domain.JobStatusCancelled:
|
case dto.JobStatusFailure, dto.JobStatusCancelled:
|
||||||
statusValue = "failed"
|
statusValue = "failed"
|
||||||
processingStatus = "FAILED"
|
processingStatus = "FAILED"
|
||||||
}
|
}
|
||||||
11
internal/service/service_misc_helpers.go
Normal file
11
internal/service/service_misc_helpers.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"stream.api/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *appServices) authenticate(ctx context.Context) (*middleware.AuthResult, error) {
|
||||||
|
return s.authenticator.Authenticate(ctx)
|
||||||
|
}
|
||||||
@@ -1,18 +1,22 @@
|
|||||||
package video
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/video/runtime/services"
|
"stream.api/internal/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AgentRuntime interface {
|
||||||
|
ListAgentsWithStats() []*dto.AgentWithStats
|
||||||
|
SendCommand(agentID string, cmd string) bool
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
ErrAdTemplateNotFound = errors.New("ad template not found")
|
ErrAdTemplateNotFound = errors.New("ad template not found")
|
||||||
@@ -21,7 +25,7 @@ var (
|
|||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
jobService *services.JobService
|
jobService *JobService
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateVideoInput struct {
|
type CreateVideoInput struct {
|
||||||
@@ -40,11 +44,11 @@ type CreateVideoResult struct {
|
|||||||
Job model.Job
|
Job model.Job
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(db *gorm.DB, jobService *services.JobService) *Service {
|
func NewService(db *gorm.DB, jobService *JobService) *Service {
|
||||||
return &Service{db: db, jobService: jobService}
|
return &Service{db: db, jobService: jobService}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) JobService() *services.JobService {
|
func (s *Service) JobService() *JobService {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -122,21 +126,21 @@ func (s *Service) CreateVideo(ctx context.Context, input CreateVideoInput) (*Cre
|
|||||||
return &CreateVideoResult{Video: video, Job: *job}, nil
|
return &CreateVideoResult{Video: video, Job: *job}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ListJobs(ctx context.Context, offset, limit int) (*PaginatedJobs, error) {
|
func (s *Service) ListJobs(ctx context.Context, offset, limit int) (*dto.PaginatedJobs, error) {
|
||||||
if s == nil || s.jobService == nil {
|
if s == nil || s.jobService == nil {
|
||||||
return nil, ErrJobServiceUnavailable
|
return nil, ErrJobServiceUnavailable
|
||||||
}
|
}
|
||||||
return s.jobService.ListJobs(ctx, offset, limit)
|
return s.jobService.ListJobs(ctx, offset, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ListJobsByAgent(ctx context.Context, agentID string, offset, limit int) (*PaginatedJobs, error) {
|
func (s *Service) ListJobsByAgent(ctx context.Context, agentID string, offset, limit int) (*dto.PaginatedJobs, error) {
|
||||||
if s == nil || s.jobService == nil {
|
if s == nil || s.jobService == nil {
|
||||||
return nil, ErrJobServiceUnavailable
|
return nil, ErrJobServiceUnavailable
|
||||||
}
|
}
|
||||||
return s.jobService.ListJobsByAgent(ctx, agentID, offset, limit)
|
return s.jobService.ListJobsByAgent(ctx, agentID, offset, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ListJobsByCursor(ctx context.Context, agentID string, cursor string, pageSize int) (*PaginatedJobs, error) {
|
func (s *Service) ListJobsByCursor(ctx context.Context, agentID string, cursor string, pageSize int) (*dto.PaginatedJobs, error) {
|
||||||
if s == nil || s.jobService == nil {
|
if s == nil || s.jobService == nil {
|
||||||
return nil, ErrJobServiceUnavailable
|
return nil, ErrJobServiceUnavailable
|
||||||
}
|
}
|
||||||
@@ -219,33 +223,3 @@ func markVideoJobFailed(ctx context.Context, db *gorm.DB, videoID string) error
|
|||||||
Where("id = ?", strings.TrimSpace(videoID)).
|
Where("id = ?", strings.TrimSpace(videoID)).
|
||||||
Updates(map[string]any{"status": "failed", "processing_status": "FAILED"}).Error
|
Updates(map[string]any{"status": "failed", "processing_status": "FAILED"}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func detectStorageType(rawURL string) string {
|
|
||||||
if shouldDeleteStoredObject(rawURL) {
|
|
||||||
return "S3"
|
|
||||||
}
|
|
||||||
return "WORKER"
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldDeleteStoredObject(rawURL string) bool {
|
|
||||||
trimmed := strings.TrimSpace(rawURL)
|
|
||||||
if trimmed == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
parsed, err := url.Parse(trimmed)
|
|
||||||
if err != nil {
|
|
||||||
return !strings.HasPrefix(trimmed, "/")
|
|
||||||
}
|
|
||||||
return parsed.Scheme == "" && parsed.Host == "" && !strings.HasPrefix(trimmed, "/")
|
|
||||||
}
|
|
||||||
|
|
||||||
func nullableTrimmedString(value *string) *string {
|
|
||||||
if value == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
trimmed := strings.TrimSpace(*value)
|
|
||||||
if trimmed == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &trimmed
|
|
||||||
}
|
|
||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
appv1 "stream.api/internal/api/proto/app/v1"
|
appv1 "stream.api/internal/api/proto/app/v1"
|
||||||
"stream.api/internal/database/model"
|
"stream.api/internal/database/model"
|
||||||
"stream.api/internal/video"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *appServices) GetUploadUrl(ctx context.Context, req *appv1.GetUploadUrlRequest) (*appv1.GetUploadUrlResponse, error) {
|
func (s *appServices) GetUploadUrl(ctx context.Context, req *appv1.GetUploadUrlRequest) (*appv1.GetUploadUrlResponse, error) {
|
||||||
@@ -59,7 +58,7 @@ func (s *appServices) CreateVideo(ctx context.Context, req *appv1.CreateVideoReq
|
|||||||
}
|
}
|
||||||
description := strings.TrimSpace(req.GetDescription())
|
description := strings.TrimSpace(req.GetDescription())
|
||||||
|
|
||||||
created, err := s.videoService.CreateVideo(ctx, video.CreateVideoInput{
|
created, err := s.videoService.CreateVideo(ctx, CreateVideoInput{
|
||||||
UserID: result.UserID,
|
UserID: result.UserID,
|
||||||
Title: title,
|
Title: title,
|
||||||
Description: &description,
|
Description: &description,
|
||||||
@@ -71,7 +70,7 @@ func (s *appServices) CreateVideo(ctx context.Context, req *appv1.CreateVideoReq
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("Failed to create video", "error", err)
|
s.logger.Error("Failed to create video", "error", err)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, video.ErrJobServiceUnavailable):
|
case errors.Is(err, ErrJobServiceUnavailable):
|
||||||
return nil, status.Error(codes.Unavailable, "Job service is unavailable")
|
return nil, status.Error(codes.Unavailable, "Job service is unavailable")
|
||||||
default:
|
default:
|
||||||
return nil, status.Error(codes.Internal, "Failed to create video")
|
return nil, status.Error(codes.Internal, "Failed to create video")
|
||||||
|
|||||||
62
internal/transport/grpc/agent_lifecycle.go
Normal file
62
internal/transport/grpc/agent_lifecycle.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
proto "stream.api/internal/api/proto/agent/v1"
|
||||||
|
"stream.api/internal/dto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) RegisterAgent(ctx context.Context, req *proto.RegisterAgentRequest) (*proto.RegisterAgentResponse, error) {
|
||||||
|
if req.Info == nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "connection info is required")
|
||||||
|
}
|
||||||
|
id, _, ok := s.getAgentIDFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
||||||
|
}
|
||||||
|
hostname := ""
|
||||||
|
if req.Info.CustomLabels != nil {
|
||||||
|
hostname = req.Info.CustomLabels["hostname"]
|
||||||
|
}
|
||||||
|
name := hostname
|
||||||
|
if name == "" {
|
||||||
|
name = fmt.Sprintf("agent-%s", id)
|
||||||
|
}
|
||||||
|
s.agentManager.Register(id, name, req.Info.Platform, req.Info.Backend, req.Info.Version, req.Info.Capacity)
|
||||||
|
if s.onAgentEvent != nil {
|
||||||
|
s.onAgentEvent("agent_update", s.getAgentWithStats(id))
|
||||||
|
}
|
||||||
|
return &proto.RegisterAgentResponse{AgentId: id}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) UnregisterAgent(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
||||||
|
agentID, token, ok := s.getAgentIDFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
||||||
|
}
|
||||||
|
for _, jobID := range s.getAgentJobs(agentID) {
|
||||||
|
_ = s.jobService.UpdateJobStatus(ctx, jobID, dto.JobStatusFailure)
|
||||||
|
s.untrackJobAssignment(agentID, jobID)
|
||||||
|
}
|
||||||
|
s.sessions.Delete(token)
|
||||||
|
s.agentJobs.Delete(agentID)
|
||||||
|
agent := s.getAgentWithStats(agentID)
|
||||||
|
s.agentManager.Unregister(agentID)
|
||||||
|
if s.onAgentEvent != nil {
|
||||||
|
s.onAgentEvent("agent_update", agent)
|
||||||
|
}
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ReportHealth(ctx context.Context, _ *proto.ReportHealthRequest) (*proto.Empty, error) {
|
||||||
|
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
||||||
|
}
|
||||||
|
s.agentManager.UpdateHeartbeat(agentID)
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"stream.api/internal/video/runtime/domain"
|
"stream.api/internal/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AgentInfo struct {
|
type AgentInfo struct {
|
||||||
@@ -95,17 +95,17 @@ func (am *AgentManager) Unregister(id string) {
|
|||||||
delete(am.agents, id)
|
delete(am.agents, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AgentManager) ListAll() []*domain.Agent {
|
func (am *AgentManager) ListAll() []*dto.Agent {
|
||||||
am.mu.RLock()
|
am.mu.RLock()
|
||||||
defer am.mu.RUnlock()
|
defer am.mu.RUnlock()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
all := make([]*domain.Agent, 0, len(am.agents))
|
all := make([]*dto.Agent, 0, len(am.agents))
|
||||||
for _, info := range am.agents {
|
for _, info := range am.agents {
|
||||||
status := domain.AgentStatusOnline
|
status := dto.AgentStatusOnline
|
||||||
if now.Sub(info.LastHeartbeat) >= 60*time.Second {
|
if now.Sub(info.LastHeartbeat) >= 60*time.Second {
|
||||||
status = domain.AgentStatusOffline
|
status = dto.AgentStatusOffline
|
||||||
}
|
}
|
||||||
all = append(all, &domain.Agent{ID: info.ID, Name: info.Name, Platform: info.Platform, Backend: info.Backend, Version: info.Version, Capacity: info.Capacity, Status: status, CPU: info.CPU, RAM: info.RAM, LastHeartbeat: info.LastHeartbeat, CreatedAt: info.ConnectedAt, UpdatedAt: info.LastHeartbeat})
|
all = append(all, &dto.Agent{ID: info.ID, Name: info.Name, Platform: info.Platform, Backend: info.Backend, Version: info.Version, Capacity: info.Capacity, Status: status, CPU: info.CPU, RAM: info.RAM, LastHeartbeat: info.LastHeartbeat, CreatedAt: info.ConnectedAt, UpdatedAt: info.LastHeartbeat})
|
||||||
}
|
}
|
||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
63
internal/transport/grpc/agent_runtime_server.go
Normal file
63
internal/transport/grpc/agent_runtime_server.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
grpcpkg "google.golang.org/grpc"
|
||||||
|
proto "stream.api/internal/api/proto/agent/v1"
|
||||||
|
"stream.api/internal/dto"
|
||||||
|
"stream.api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
proto.UnimplementedWoodpeckerServer
|
||||||
|
proto.UnimplementedWoodpeckerAuthServer
|
||||||
|
jobService *service.JobService
|
||||||
|
agentManager *AgentManager
|
||||||
|
agentSecret string
|
||||||
|
sessions sync.Map
|
||||||
|
agentJobs sync.Map
|
||||||
|
onAgentEvent func(string, *dto.AgentWithStats)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(jobService *service.JobService, agentSecret string) *Server {
|
||||||
|
return &Server{jobService: jobService, agentManager: NewAgentManager(), agentSecret: agentSecret}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) SetAgentEventHandler(handler func(string, *dto.AgentWithStats)) {
|
||||||
|
s.onAgentEvent = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Register(grpcServer grpcpkg.ServiceRegistrar) {
|
||||||
|
proto.RegisterWoodpeckerServer(grpcServer, s)
|
||||||
|
proto.RegisterWoodpeckerAuthServer(grpcServer, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) SendCommand(agentID string, cmd string) bool {
|
||||||
|
return s.agentManager.SendCommand(agentID, cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) ListAgents() []*dto.Agent { return s.agentManager.ListAll() }
|
||||||
|
|
||||||
|
func (s *Server) ListAgentsWithStats() []*dto.AgentWithStats {
|
||||||
|
agents := s.agentManager.ListAll()
|
||||||
|
result := make([]*dto.AgentWithStats, 0, len(agents))
|
||||||
|
for _, agent := range agents {
|
||||||
|
result = append(result, &dto.AgentWithStats{Agent: agent, ActiveJobCount: int64(len(s.getAgentJobs(agent.ID)))})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getAgentWithStats(agentID string) *dto.AgentWithStats {
|
||||||
|
for _, agent := range s.ListAgentsWithStats() {
|
||||||
|
if agent != nil && agent.Agent != nil && agent.Agent.ID == agentID {
|
||||||
|
return agent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Version(context.Context, *proto.Empty) (*proto.VersionResponse, error) {
|
||||||
|
return &proto.VersionResponse{GrpcVersion: 15, ServerVersion: "stream.api"}, nil
|
||||||
|
}
|
||||||
43
internal/transport/grpc/assignments.go
Normal file
43
internal/transport/grpc/assignments.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
func (s *Server) trackJobAssignment(agentID, jobID string) {
|
||||||
|
jobSetInterface, _ := s.agentJobs.LoadOrStore(agentID, &sync.Map{})
|
||||||
|
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
||||||
|
jobSet.Store(jobID, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) untrackJobAssignment(agentID, jobID string) {
|
||||||
|
if jobSetInterface, ok := s.agentJobs.Load(agentID); ok {
|
||||||
|
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
||||||
|
jobSet.Delete(jobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) isJobAssigned(agentID, jobID string) bool {
|
||||||
|
if jobSetInterface, ok := s.agentJobs.Load(agentID); ok {
|
||||||
|
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
||||||
|
_, found := jobSet.Load(jobID)
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getAgentJobs(agentID string) []string {
|
||||||
|
jobs := []string{}
|
||||||
|
if jobSetInterface, ok := s.agentJobs.Load(agentID); ok {
|
||||||
|
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
||||||
|
jobSet.Range(func(key, _ any) bool {
|
||||||
|
if jobID, ok := key.(string); ok {
|
||||||
|
jobs = append(jobs, jobID)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jobs
|
||||||
|
}
|
||||||
56
internal/transport/grpc/auth.go
Normal file
56
internal/transport/grpc/auth.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
proto "stream.api/internal/api/proto/agent/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateToken() string {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateAgentID() string {
|
||||||
|
return strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getAgentIDFromContext(ctx context.Context) (string, string, bool) {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
tokens := md.Get("token")
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
token := tokens[0]
|
||||||
|
if id, ok := s.sessions.Load(token); ok {
|
||||||
|
return id.(string), token, true
|
||||||
|
}
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Auth(ctx context.Context, req *proto.AuthRequest) (*proto.AuthResponse, error) {
|
||||||
|
if s.agentSecret != "" && req.AgentToken != s.agentSecret {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid agent secret")
|
||||||
|
}
|
||||||
|
agentID := req.AgentId
|
||||||
|
if len(agentID) > 6 && agentID[:6] == "agent-" {
|
||||||
|
agentID = agentID[6:]
|
||||||
|
}
|
||||||
|
if agentID == "" {
|
||||||
|
agentID = generateAgentID()
|
||||||
|
}
|
||||||
|
accessToken := generateToken()
|
||||||
|
s.sessions.Store(accessToken, agentID)
|
||||||
|
return &proto.AuthResponse{Status: "ok", AgentId: agentID, AccessToken: accessToken}, nil
|
||||||
|
}
|
||||||
@@ -8,44 +8,39 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
redisadapter "stream.api/internal/adapters/redis"
|
redisadapter "stream.api/internal/adapters/redis"
|
||||||
"stream.api/internal/config"
|
"stream.api/internal/config"
|
||||||
|
"stream.api/internal/dto"
|
||||||
"stream.api/internal/service"
|
"stream.api/internal/service"
|
||||||
"stream.api/internal/video"
|
"stream.api/internal/transport/mqtt"
|
||||||
runtime "stream.api/internal/video/runtime"
|
|
||||||
runtimegrpc "stream.api/internal/video/runtime/grpc"
|
|
||||||
"stream.api/internal/video/runtime/services"
|
|
||||||
"stream.api/pkg/logger"
|
"stream.api/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GRPCModule struct {
|
type GRPCModule struct {
|
||||||
jobService *services.JobService
|
jobService *service.JobService
|
||||||
healthService *services.HealthService
|
agentRuntime *Server
|
||||||
agentRuntime *runtimegrpc.Server
|
mqttPublisher *mqtt.MQTTBootstrap
|
||||||
mqttPublisher *runtime.MQTTBootstrap
|
|
||||||
grpcServer *grpcpkg.Server
|
grpcServer *grpcpkg.Server
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGRPCModule(ctx context.Context, cfg *config.Config, db *gorm.DB, rds *redisadapter.RedisAdapter, appLogger logger.Logger) (*GRPCModule, error) {
|
func NewGRPCModule(ctx context.Context, cfg *config.Config, db *gorm.DB, rds *redisadapter.RedisAdapter, appLogger logger.Logger) (*GRPCModule, error) {
|
||||||
jobService := services.NewJobService(rds, rds)
|
jobService := service.NewJobService(rds, rds)
|
||||||
healthService := services.NewHealthService(db, rds.Client(), cfg.Render.ServiceName)
|
agentRuntime := NewServer(jobService, cfg.Render.AgentSecret)
|
||||||
agentRuntime := runtimegrpc.NewServer(jobService, cfg.Render.AgentSecret)
|
videoService := service.NewService(db, jobService)
|
||||||
videoService := video.NewService(db, jobService)
|
|
||||||
grpcServer := grpcpkg.NewServer()
|
grpcServer := grpcpkg.NewServer()
|
||||||
|
|
||||||
module := &GRPCModule{
|
module := &GRPCModule{
|
||||||
jobService: jobService,
|
jobService: jobService,
|
||||||
healthService: healthService,
|
agentRuntime: agentRuntime,
|
||||||
agentRuntime: agentRuntime,
|
grpcServer: grpcServer,
|
||||||
grpcServer: grpcServer,
|
cfg: cfg,
|
||||||
cfg: cfg,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if publisher, err := runtime.NewMQTTBootstrap(jobService, agentRuntime, appLogger); err != nil {
|
if publisher, err := mqtt.NewMQTTBootstrap(jobService, agentRuntime, appLogger); err != nil {
|
||||||
appLogger.Error("Failed to initialize MQTT publisher", "error", err)
|
appLogger.Error("Failed to initialize MQTT publisher", "error", err)
|
||||||
} else {
|
} else {
|
||||||
module.mqttPublisher = publisher
|
module.mqttPublisher = publisher
|
||||||
agentRuntime.SetAgentEventHandler(func(eventType string, agent *services.AgentWithStats) {
|
agentRuntime.SetAgentEventHandler(func(eventType string, agent *dto.AgentWithStats) {
|
||||||
runtime.PublishAgentMQTTEvent(publisher.Client(), appLogger, eventType, agent)
|
mqtt.PublishAgentMQTTEvent(publisher.Client(), appLogger, eventType, agent)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,8 +53,8 @@ func NewGRPCModule(ctx context.Context, cfg *config.Config, db *gorm.DB, rds *re
|
|||||||
return module, nil
|
return module, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GRPCModule) JobService() *services.JobService { return m.jobService }
|
func (m *GRPCModule) JobService() *service.JobService { return m.jobService }
|
||||||
func (m *GRPCModule) AgentRuntime() *runtimegrpc.Server { return m.agentRuntime }
|
func (m *GRPCModule) AgentRuntime() *Server { return m.agentRuntime }
|
||||||
func (m *GRPCModule) GRPCServer() *grpcpkg.Server { return m.grpcServer }
|
func (m *GRPCModule) GRPCServer() *grpcpkg.Server { return m.grpcServer }
|
||||||
func (m *GRPCModule) GRPCAddress() string { return ":" + m.cfg.Server.GRPCPort }
|
func (m *GRPCModule) GRPCAddress() string { return ":" + m.cfg.Server.GRPCPort }
|
||||||
func (m *GRPCModule) ServeGRPC(listener net.Listener) error { return m.grpcServer.Serve(listener) }
|
func (m *GRPCModule) ServeGRPC(listener net.Listener) error { return m.grpcServer.Serve(listener) }
|
||||||
|
|||||||
170
internal/transport/grpc/stream_handlers.go
Normal file
170
internal/transport/grpc/stream_handlers.go
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
grpcpkg "google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
proto "stream.api/internal/api/proto/agent/v1"
|
||||||
|
"stream.api/internal/dto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) Next(context.Context, *proto.NextRequest) (*proto.NextResponse, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "use StreamJobs")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) StreamJobs(_ *proto.StreamOptions, stream grpcpkg.ServerStreamingServer[proto.Workflow]) error {
|
||||||
|
ctx := stream.Context()
|
||||||
|
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return status.Error(codes.Unauthenticated, "invalid or missing token")
|
||||||
|
}
|
||||||
|
s.agentManager.UpdateHeartbeat(agentID)
|
||||||
|
cancelCh, _ := s.jobService.SubscribeCancel(ctx, agentID)
|
||||||
|
commandCh, _ := s.agentManager.GetCommandChannel(agentID)
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case cmd := <-commandCh:
|
||||||
|
payload, _ := json.Marshal(map[string]any{"image": "alpine", "commands": []string{"echo 'System Command'"}, "environment": map[string]string{}, "action": cmd})
|
||||||
|
if err := stream.Send(&proto.Workflow{Id: fmt.Sprintf("cmd-%s-%d", agentID, time.Now().UnixNano()), Timeout: 300, Payload: payload}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case jobID := <-cancelCh:
|
||||||
|
if s.isJobAssigned(agentID, jobID) {
|
||||||
|
if err := stream.Send(&proto.Workflow{Id: jobID, Cancel: true}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
case <-ticker.C:
|
||||||
|
s.agentManager.UpdateHeartbeat(agentID)
|
||||||
|
jobCtx, cancel := context.WithTimeout(ctx, time.Second)
|
||||||
|
job, err := s.jobService.GetNextJob(jobCtx)
|
||||||
|
cancel()
|
||||||
|
if err != nil || job == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.trackJobAssignment(agentID, job.ID)
|
||||||
|
if err := s.jobService.AssignJob(ctx, job.ID, agentID); err != nil {
|
||||||
|
s.untrackJobAssignment(agentID, job.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var config map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(*job.Config), &config); err != nil {
|
||||||
|
_ = s.jobService.UpdateJobStatus(ctx, job.ID, dto.JobStatusFailure)
|
||||||
|
s.untrackJobAssignment(agentID, job.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
image, _ := config["image"].(string)
|
||||||
|
if image == "" {
|
||||||
|
image = "alpine"
|
||||||
|
}
|
||||||
|
commands := []string{"echo 'No commands specified'"}
|
||||||
|
if raw, ok := config["commands"].([]any); ok && len(raw) > 0 {
|
||||||
|
commands = commands[:0]
|
||||||
|
for _, item := range raw {
|
||||||
|
if text, ok := item.(string); ok {
|
||||||
|
commands = append(commands, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(commands) == 0 {
|
||||||
|
commands = []string{"echo 'No commands specified'"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payload, _ := json.Marshal(map[string]any{"image": image, "commands": commands, "environment": map[string]string{}})
|
||||||
|
if err := stream.Send(&proto.Workflow{Id: job.ID, Timeout: 60 * 60 * 1000, Payload: payload}); err != nil {
|
||||||
|
_ = s.jobService.UpdateJobStatus(ctx, job.ID, dto.JobStatusPending)
|
||||||
|
s.untrackJobAssignment(agentID, job.ID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) SubmitStatus(stream grpcpkg.ClientStreamingServer[proto.StatusUpdate, proto.Empty]) error {
|
||||||
|
ctx := stream.Context()
|
||||||
|
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return status.Error(codes.Unauthenticated, "invalid or missing token")
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
update, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return stream.SendAndClose(&proto.Empty{})
|
||||||
|
}
|
||||||
|
switch update.Type {
|
||||||
|
case 0, 1:
|
||||||
|
_ = s.jobService.ProcessLog(ctx, update.StepUuid, update.Data)
|
||||||
|
case 4:
|
||||||
|
var progress float64
|
||||||
|
fmt.Sscanf(string(update.Data), "%f", &progress)
|
||||||
|
_ = s.jobService.UpdateJobProgress(ctx, update.StepUuid, progress)
|
||||||
|
case 5:
|
||||||
|
var stats struct {
|
||||||
|
CPU float64 `json:"cpu"`
|
||||||
|
RAM float64 `json:"ram"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(update.Data, &stats) == nil {
|
||||||
|
s.agentManager.UpdateResources(agentID, stats.CPU, stats.RAM)
|
||||||
|
if s.onAgentEvent != nil {
|
||||||
|
s.onAgentEvent("agent_update", s.getAgentWithStats(agentID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = s.jobService.PublishSystemResources(ctx, agentID, update.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Init(ctx context.Context, req *proto.InitRequest) (*proto.Empty, error) {
|
||||||
|
if err := s.jobService.UpdateJobStatus(ctx, req.Id, dto.JobStatusRunning); err != nil {
|
||||||
|
return nil, status.Error(codes.Internal, "failed to update job status")
|
||||||
|
}
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Wait(context.Context, *proto.WaitRequest) (*proto.WaitResponse, error) {
|
||||||
|
return &proto.WaitResponse{Canceled: false}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Done(ctx context.Context, req *proto.DoneRequest) (*proto.Empty, error) {
|
||||||
|
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
||||||
|
}
|
||||||
|
jobStatus := dto.JobStatusSuccess
|
||||||
|
if req.State != nil && req.State.Error != "" {
|
||||||
|
jobStatus = dto.JobStatusFailure
|
||||||
|
}
|
||||||
|
if err := s.jobService.UpdateJobStatus(ctx, req.Id, jobStatus); err != nil {
|
||||||
|
return nil, status.Error(codes.Internal, "failed to update job status")
|
||||||
|
}
|
||||||
|
s.untrackJobAssignment(agentID, req.Id)
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Update(context.Context, *proto.UpdateRequest) (*proto.Empty, error) {
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Log(ctx context.Context, req *proto.LogRequest) (*proto.Empty, error) {
|
||||||
|
if _, _, ok := s.getAgentIDFromContext(ctx); !ok {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
||||||
|
}
|
||||||
|
for _, entry := range req.LogEntries {
|
||||||
|
if entry.StepUuid != "" {
|
||||||
|
_ = s.jobService.ProcessLog(ctx, entry.StepUuid, entry.Data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Extend(context.Context, *proto.ExtendRequest) (*proto.Empty, error) {
|
||||||
|
return &proto.Empty{}, nil
|
||||||
|
}
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
package mqtt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Client interface {
|
|
||||||
Publish(ctx context.Context, topic string, payload []byte) error
|
|
||||||
Subscribe(topic string, handler MessageHandler) error
|
|
||||||
Disconnect()
|
|
||||||
}
|
|
||||||
|
|
||||||
type MessageHandler func(topic string, payload []byte)
|
|
||||||
|
|
||||||
type client struct {
|
|
||||||
cli mqtt.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClient(broker string, clientID string) (Client, error) {
|
|
||||||
opts := mqtt.NewClientOptions().
|
|
||||||
AddBroker(broker).
|
|
||||||
SetClientID(clientID).
|
|
||||||
SetAutoReconnect(true).
|
|
||||||
SetConnectRetry(true).
|
|
||||||
SetConnectRetryInterval(3 * time.Second)
|
|
||||||
|
|
||||||
opts.OnConnect = func(c mqtt.Client) {
|
|
||||||
fmt.Println("MQTT connected")
|
|
||||||
}
|
|
||||||
|
|
||||||
opts.OnConnectionLost = func(c mqtt.Client, err error) {
|
|
||||||
fmt.Println("MQTT connection lost:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := mqtt.NewClient(opts)
|
|
||||||
|
|
||||||
token := c.Connect()
|
|
||||||
if ok := token.WaitTimeout(5 * time.Second); !ok {
|
|
||||||
return nil, fmt.Errorf("mqtt connect timeout")
|
|
||||||
}
|
|
||||||
if err := token.Error(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &client{cli: c}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Publish(ctx context.Context, topic string, payload []byte) error {
|
|
||||||
token := c.cli.Publish(topic, 1, false, payload)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-waitToken(token):
|
|
||||||
return token.Error()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitToken(t mqtt.Token) <-chan struct{} {
|
|
||||||
ch := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
t.Wait()
|
|
||||||
close(ch)
|
|
||||||
}()
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Subscribe(topic string, handler MessageHandler) error {
|
|
||||||
token := c.cli.Subscribe(topic, 1, func(client mqtt.Client, msg mqtt.Message) {
|
|
||||||
handler(msg.Topic(), msg.Payload())
|
|
||||||
})
|
|
||||||
|
|
||||||
if token.Wait() && token.Error() != nil {
|
|
||||||
return token.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Disconnect() {
|
|
||||||
c.cli.Disconnect(250)
|
|
||||||
}
|
|
||||||
244
internal/transport/mqtt/mqtt_publisher.go
Normal file
244
internal/transport/mqtt/mqtt_publisher.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
package mqtt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pahomqtt "github.com/eclipse/paho.mqtt.golang"
|
||||||
|
"stream.api/internal/dto"
|
||||||
|
"stream.api/internal/service"
|
||||||
|
"stream.api/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMQTTBrokerURL = "tcp://broker.mqtt-dashboard.com:1883"
|
||||||
|
defaultMQTTPrefix = "picpic"
|
||||||
|
defaultPublishWait = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type agentRuntime interface {
|
||||||
|
ListAgentsWithStats() []*dto.AgentWithStats
|
||||||
|
}
|
||||||
|
|
||||||
|
type mqttPublisher struct {
|
||||||
|
client pahomqtt.Client
|
||||||
|
jobService *service.JobService
|
||||||
|
agentRT agentRuntime
|
||||||
|
logger logger.Logger
|
||||||
|
prefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMQTTPublisher(jobService *service.JobService, agentRT agentRuntime, appLogger logger.Logger) (*mqttPublisher, error) {
|
||||||
|
client, err := connectPahoClient(defaultMQTTBrokerURL, fmt.Sprintf("stream-api-%d", time.Now().UnixNano()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &mqttPublisher{
|
||||||
|
client: client,
|
||||||
|
jobService: jobService,
|
||||||
|
agentRT: agentRT,
|
||||||
|
logger: appLogger,
|
||||||
|
prefix: defaultMQTTPrefix,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) start(ctx context.Context) {
|
||||||
|
if p == nil || p.jobService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go p.consumeLogs(ctx)
|
||||||
|
go p.consumeJobUpdates(ctx)
|
||||||
|
go p.consumeResources(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) consumeLogs(ctx context.Context) {
|
||||||
|
ch, err := p.jobService.SubscribeJobLogs(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Error("Failed to subscribe job logs for MQTT", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case entry, ok := <-ch:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.publishJSON(p.logTopic(entry.JobID), entry); err != nil {
|
||||||
|
p.logger.Error("Failed to publish MQTT job log", "error", err, "job_id", entry.JobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) consumeJobUpdates(ctx context.Context) {
|
||||||
|
ch, err := p.jobService.SubscribeJobUpdates(ctx)
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Error("Failed to subscribe job updates for MQTT", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case msg, ok := <-ch:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(msg), &payload); err != nil {
|
||||||
|
p.logger.Error("Failed to decode MQTT job update payload", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
jobID, _ := payload["job_id"].(string)
|
||||||
|
if err := p.publishEvent("job_update", payload, jobID); err != nil {
|
||||||
|
p.logger.Error("Failed to publish MQTT job update", "error", err, "job_id", jobID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) consumeResources(ctx context.Context) {
|
||||||
|
ch, err := p.jobService.SubscribeSystemResources(ctx)
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Error("Failed to subscribe resources for MQTT", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case entry, ok := <-ch:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.publishEvent("resource_update", entry, ""); err != nil {
|
||||||
|
p.logger.Error("Failed to publish MQTT resource update", "error", err, "agent_id", entry.AgentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
agent := p.findAgent(entry.AgentID)
|
||||||
|
if agent == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := p.publishEvent("agent_update", mapAgentPayload(agent), ""); err != nil {
|
||||||
|
p.logger.Error("Failed to publish MQTT agent update", "error", err, "agent_id", entry.AgentID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) findAgent(agentID string) *dto.AgentWithStats {
|
||||||
|
if p == nil || p.agentRT == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, agent := range p.agentRT.ListAgentsWithStats() {
|
||||||
|
if agent != nil && agent.Agent != nil && agent.Agent.ID == agentID {
|
||||||
|
return agent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) publishEvent(eventType string, payload any, jobID string) error {
|
||||||
|
message := mqttEvent{Type: eventType, Payload: payload}
|
||||||
|
if jobID != "" {
|
||||||
|
if err := p.publishJSON(p.jobTopic(jobID), message); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.publishJSON(p.eventsTopic(), message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) publishJSON(topic string, payload any) error {
|
||||||
|
encoded, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return p.publish(topic, encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) publish(topic string, payload []byte) error {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return publishPahoMessage(p.client, topic, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) logTopic(jobID string) string {
|
||||||
|
return fmt.Sprintf("%s/logs/%s", p.prefix, jobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) jobTopic(jobID string) string {
|
||||||
|
return fmt.Sprintf("%s/job/%s", p.prefix, jobID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mqttPublisher) eventsTopic() string {
|
||||||
|
return fmt.Sprintf("%s/events", p.prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapAgentPayload(agent *dto.AgentWithStats) map[string]any {
|
||||||
|
if agent == nil || agent.Agent == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"id": agent.Agent.ID,
|
||||||
|
"name": agent.Name,
|
||||||
|
"platform": agent.Platform,
|
||||||
|
"backend": agent.Backend,
|
||||||
|
"version": agent.Version,
|
||||||
|
"capacity": agent.Capacity,
|
||||||
|
"status": string(agent.Status),
|
||||||
|
"cpu": agent.CPU,
|
||||||
|
"ram": agent.RAM,
|
||||||
|
"last_heartbeat": agent.LastHeartbeat,
|
||||||
|
"created_at": agent.CreatedAt,
|
||||||
|
"updated_at": agent.UpdatedAt,
|
||||||
|
"active_job_count": agent.ActiveJobCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mqttEvent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Payload any `json:"payload"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MQTTBootstrap struct{ *mqttPublisher }
|
||||||
|
|
||||||
|
func NewMQTTBootstrap(jobService *service.JobService, agentRT agentRuntime, appLogger logger.Logger) (*MQTTBootstrap, error) {
|
||||||
|
publisher, err := newMQTTPublisher(jobService, agentRT, appLogger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &MQTTBootstrap{mqttPublisher: publisher}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *MQTTBootstrap) Start(ctx context.Context) {
|
||||||
|
if b == nil || b.mqttPublisher == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.mqttPublisher.start(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *MQTTBootstrap) Client() pahomqtt.Client {
|
||||||
|
if b == nil || b.mqttPublisher == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return b.client
|
||||||
|
}
|
||||||
|
|
||||||
|
func PublishAgentMQTTEvent(client pahomqtt.Client, appLogger logger.Logger, eventType string, agent *dto.AgentWithStats) {
|
||||||
|
publishMQTTEvent(client, appLogger, defaultMQTTPrefix, mqttEvent{
|
||||||
|
Type: eventType,
|
||||||
|
Payload: mapAgentPayload(agent),
|
||||||
|
})
|
||||||
|
}
|
||||||
59
internal/transport/mqtt/paho_helpers.go
Normal file
59
internal/transport/mqtt/paho_helpers.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package mqtt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pahomqtt "github.com/eclipse/paho.mqtt.golang"
|
||||||
|
"stream.api/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func connectPahoClient(broker, clientID string) (pahomqtt.Client, error) {
|
||||||
|
opts := pahomqtt.NewClientOptions().
|
||||||
|
AddBroker(broker).
|
||||||
|
SetClientID(clientID).
|
||||||
|
SetAutoReconnect(true).
|
||||||
|
SetConnectRetry(true).
|
||||||
|
SetKeepAlive(60 * time.Second).
|
||||||
|
SetPingTimeout(10 * time.Second).
|
||||||
|
SetConnectRetryInterval(3 * time.Second)
|
||||||
|
|
||||||
|
client := pahomqtt.NewClient(opts)
|
||||||
|
token := client.Connect()
|
||||||
|
if ok := token.WaitTimeout(defaultPublishWait); !ok {
|
||||||
|
return nil, fmt.Errorf("mqtt connect timeout")
|
||||||
|
}
|
||||||
|
if err := token.Error(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func publishMQTTEvent(client pahomqtt.Client, appLogger logger.Logger, prefix string, event mqttEvent) {
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
appLogger.Error("Failed to marshal MQTT event", "error", err, "type", event.Type)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := publishPahoMessage(client, fmt.Sprintf("%s/events", prefix), encoded); err != nil {
|
||||||
|
appLogger.Error("Failed to publish MQTT event", "error", err, "type", event.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func publishPahoMessage(client pahomqtt.Client, topic string, payload []byte) error {
|
||||||
|
if client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
token := client.Publish(topic, 0, false, payload)
|
||||||
|
if ok := token.WaitTimeout(defaultPublishWait); !ok {
|
||||||
|
return fmt.Errorf("mqtt publish timeout")
|
||||||
|
}
|
||||||
|
return token.Error()
|
||||||
|
}
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package video
|
|
||||||
|
|
||||||
import (
|
|
||||||
runtimeservices "stream.api/internal/video/runtime/services"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AgentWithStats = runtimeservices.AgentWithStats
|
|
||||||
|
|
||||||
type PaginatedJobs = runtimeservices.PaginatedJobs
|
|
||||||
|
|
||||||
var ErrInvalidJobCursor = runtimeservices.ErrInvalidJobCursor
|
|
||||||
|
|
||||||
type AgentRuntime interface {
|
|
||||||
ListAgentsWithStats() []*AgentWithStats
|
|
||||||
SendCommand(agentID string, cmd string) bool
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package domain
|
|
||||||
|
|
||||||
type JobStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
JobStatusPending JobStatus = "pending"
|
|
||||||
JobStatusRunning JobStatus = "running"
|
|
||||||
JobStatusSuccess JobStatus = "success"
|
|
||||||
JobStatusFailure JobStatus = "failure"
|
|
||||||
JobStatusCancelled JobStatus = "cancelled"
|
|
||||||
)
|
|
||||||
@@ -1,363 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
grpcpkg "google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"stream.api/internal/video/runtime/domain"
|
|
||||||
"stream.api/internal/video/runtime/proto"
|
|
||||||
"stream.api/internal/video/runtime/services"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
proto.UnimplementedWoodpeckerServer
|
|
||||||
proto.UnimplementedWoodpeckerAuthServer
|
|
||||||
jobService *services.JobService
|
|
||||||
agentManager *AgentManager
|
|
||||||
agentSecret string
|
|
||||||
sessions sync.Map
|
|
||||||
agentJobs sync.Map
|
|
||||||
onAgentEvent func(string, *services.AgentWithStats)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServer(jobService *services.JobService, agentSecret string) *Server {
|
|
||||||
return &Server{jobService: jobService, agentManager: NewAgentManager(), agentSecret: agentSecret}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) SetAgentEventHandler(handler func(string, *services.AgentWithStats)) {
|
|
||||||
s.onAgentEvent = handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Register(grpcServer grpcpkg.ServiceRegistrar) {
|
|
||||||
proto.RegisterWoodpeckerServer(grpcServer, s)
|
|
||||||
proto.RegisterWoodpeckerAuthServer(grpcServer, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) SendCommand(agentID string, cmd string) bool {
|
|
||||||
return s.agentManager.SendCommand(agentID, cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) ListAgents() []*domain.Agent { return s.agentManager.ListAll() }
|
|
||||||
|
|
||||||
func (s *Server) ListAgentsWithStats() []*services.AgentWithStats {
|
|
||||||
agents := s.agentManager.ListAll()
|
|
||||||
result := make([]*services.AgentWithStats, 0, len(agents))
|
|
||||||
for _, agent := range agents {
|
|
||||||
result = append(result, &services.AgentWithStats{Agent: agent, ActiveJobCount: int64(len(s.getAgentJobs(agent.ID)))})
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getAgentWithStats(agentID string) *services.AgentWithStats {
|
|
||||||
for _, agent := range s.ListAgentsWithStats() {
|
|
||||||
if agent != nil && agent.Agent != nil && agent.ID == agentID {
|
|
||||||
return agent
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Version(context.Context, *proto.Empty) (*proto.VersionResponse, error) {
|
|
||||||
return &proto.VersionResponse{GrpcVersion: 15, ServerVersion: "stream.api"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateToken() string {
|
|
||||||
b := make([]byte, 16)
|
|
||||||
_, _ = rand.Read(b)
|
|
||||||
return hex.EncodeToString(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateAgentID() string {
|
|
||||||
return strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getAgentIDFromContext(ctx context.Context) (string, string, bool) {
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return "", "", false
|
|
||||||
}
|
|
||||||
tokens := md.Get("token")
|
|
||||||
if len(tokens) == 0 {
|
|
||||||
return "", "", false
|
|
||||||
}
|
|
||||||
token := tokens[0]
|
|
||||||
if id, ok := s.sessions.Load(token); ok {
|
|
||||||
return id.(string), token, true
|
|
||||||
}
|
|
||||||
return "", "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Next(context.Context, *proto.NextRequest) (*proto.NextResponse, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "use StreamJobs")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) StreamJobs(_ *proto.StreamOptions, stream grpcpkg.ServerStreamingServer[proto.Workflow]) error {
|
|
||||||
ctx := stream.Context()
|
|
||||||
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return status.Error(codes.Unauthenticated, "invalid or missing token")
|
|
||||||
}
|
|
||||||
s.agentManager.UpdateHeartbeat(agentID)
|
|
||||||
cancelCh, _ := s.jobService.SubscribeCancel(ctx, agentID)
|
|
||||||
commandCh, _ := s.agentManager.GetCommandChannel(agentID)
|
|
||||||
ticker := time.NewTicker(2 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case cmd := <-commandCh:
|
|
||||||
payload, _ := json.Marshal(map[string]any{"image": "alpine", "commands": []string{"echo 'System Command'"}, "environment": map[string]string{}, "action": cmd})
|
|
||||||
if err := stream.Send(&proto.Workflow{Id: fmt.Sprintf("cmd-%s-%d", agentID, time.Now().UnixNano()), Timeout: 300, Payload: payload}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case jobID := <-cancelCh:
|
|
||||||
if s.isJobAssigned(agentID, jobID) {
|
|
||||||
if err := stream.Send(&proto.Workflow{Id: jobID, Cancel: true}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
case <-ticker.C:
|
|
||||||
s.agentManager.UpdateHeartbeat(agentID)
|
|
||||||
jobCtx, cancel := context.WithTimeout(ctx, time.Second)
|
|
||||||
job, err := s.jobService.GetNextJob(jobCtx)
|
|
||||||
cancel()
|
|
||||||
if err != nil || job == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s.trackJobAssignment(agentID, job.ID)
|
|
||||||
if err := s.jobService.AssignJob(ctx, job.ID, agentID); err != nil {
|
|
||||||
s.untrackJobAssignment(agentID, job.ID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var config map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(*job.Config), &config); err != nil {
|
|
||||||
_ = s.jobService.UpdateJobStatus(ctx, job.ID, domain.JobStatusFailure)
|
|
||||||
s.untrackJobAssignment(agentID, job.ID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
image, _ := config["image"].(string)
|
|
||||||
if image == "" {
|
|
||||||
image = "alpine"
|
|
||||||
}
|
|
||||||
commands := []string{"echo 'No commands specified'"}
|
|
||||||
if raw, ok := config["commands"].([]any); ok && len(raw) > 0 {
|
|
||||||
commands = commands[:0]
|
|
||||||
for _, item := range raw {
|
|
||||||
if text, ok := item.(string); ok {
|
|
||||||
commands = append(commands, text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(commands) == 0 {
|
|
||||||
commands = []string{"echo 'No commands specified'"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
payload, _ := json.Marshal(map[string]any{"image": image, "commands": commands, "environment": map[string]string{}})
|
|
||||||
// Sau này xem xét có cần cho job.TimeLimit vào db không?
|
|
||||||
// Hiện tại để đơn giản thì cứ để mặc định timeout 1h, nếu job nào cần timeout ngắn hơn thì tự lo trong commands của nó
|
|
||||||
if err := stream.Send(&proto.Workflow{Id: job.ID, Timeout: 60 * 60 * 1000, Payload: payload}); err != nil {
|
|
||||||
_ = s.jobService.UpdateJobStatus(ctx, job.ID, domain.JobStatusPending)
|
|
||||||
s.untrackJobAssignment(agentID, job.ID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) SubmitStatus(stream grpcpkg.ClientStreamingServer[proto.StatusUpdate, proto.Empty]) error {
|
|
||||||
ctx := stream.Context()
|
|
||||||
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return status.Error(codes.Unauthenticated, "invalid or missing token")
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
update, err := stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
return stream.SendAndClose(&proto.Empty{})
|
|
||||||
}
|
|
||||||
switch update.Type {
|
|
||||||
case 0, 1:
|
|
||||||
_ = s.jobService.ProcessLog(ctx, update.StepUuid, update.Data)
|
|
||||||
case 4:
|
|
||||||
var progress float64
|
|
||||||
fmt.Sscanf(string(update.Data), "%f", &progress)
|
|
||||||
_ = s.jobService.UpdateJobProgress(ctx, update.StepUuid, progress)
|
|
||||||
case 5:
|
|
||||||
var stats struct {
|
|
||||||
CPU float64 `json:"cpu"`
|
|
||||||
RAM float64 `json:"ram"`
|
|
||||||
}
|
|
||||||
if json.Unmarshal(update.Data, &stats) == nil {
|
|
||||||
s.agentManager.UpdateResources(agentID, stats.CPU, stats.RAM)
|
|
||||||
if s.onAgentEvent != nil {
|
|
||||||
s.onAgentEvent("agent_update", s.getAgentWithStats(agentID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = s.jobService.PublishSystemResources(ctx, agentID, update.Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Init(ctx context.Context, req *proto.InitRequest) (*proto.Empty, error) {
|
|
||||||
if err := s.jobService.UpdateJobStatus(ctx, req.Id, domain.JobStatusRunning); err != nil {
|
|
||||||
return nil, status.Error(codes.Internal, "failed to update job status")
|
|
||||||
}
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Wait(context.Context, *proto.WaitRequest) (*proto.WaitResponse, error) {
|
|
||||||
return &proto.WaitResponse{Canceled: false}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Done(ctx context.Context, req *proto.DoneRequest) (*proto.Empty, error) {
|
|
||||||
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
|
||||||
}
|
|
||||||
jobStatus := domain.JobStatusSuccess
|
|
||||||
if req.State != nil && req.State.Error != "" {
|
|
||||||
jobStatus = domain.JobStatusFailure
|
|
||||||
}
|
|
||||||
if err := s.jobService.UpdateJobStatus(ctx, req.Id, jobStatus); err != nil {
|
|
||||||
return nil, status.Error(codes.Internal, "failed to update job status")
|
|
||||||
}
|
|
||||||
s.untrackJobAssignment(agentID, req.Id)
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Update(context.Context, *proto.UpdateRequest) (*proto.Empty, error) {
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Log(ctx context.Context, req *proto.LogRequest) (*proto.Empty, error) {
|
|
||||||
if _, _, ok := s.getAgentIDFromContext(ctx); !ok {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
|
||||||
}
|
|
||||||
for _, entry := range req.LogEntries {
|
|
||||||
if entry.StepUuid != "" {
|
|
||||||
_ = s.jobService.ProcessLog(ctx, entry.StepUuid, entry.Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Extend(context.Context, *proto.ExtendRequest) (*proto.Empty, error) {
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) RegisterAgent(ctx context.Context, req *proto.RegisterAgentRequest) (*proto.RegisterAgentResponse, error) {
|
|
||||||
if req.Info == nil {
|
|
||||||
return nil, status.Error(codes.InvalidArgument, "connection info is required")
|
|
||||||
}
|
|
||||||
id, _, ok := s.getAgentIDFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
|
||||||
}
|
|
||||||
hostname := ""
|
|
||||||
if req.Info.CustomLabels != nil {
|
|
||||||
hostname = req.Info.CustomLabels["hostname"]
|
|
||||||
}
|
|
||||||
name := hostname
|
|
||||||
if name == "" {
|
|
||||||
name = fmt.Sprintf("agent-%s", id)
|
|
||||||
}
|
|
||||||
s.agentManager.Register(id, name, req.Info.Platform, req.Info.Backend, req.Info.Version, req.Info.Capacity)
|
|
||||||
if s.onAgentEvent != nil {
|
|
||||||
s.onAgentEvent("agent_update", s.getAgentWithStats(id))
|
|
||||||
}
|
|
||||||
return &proto.RegisterAgentResponse{AgentId: id}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) UnregisterAgent(ctx context.Context, _ *proto.Empty) (*proto.Empty, error) {
|
|
||||||
agentID, token, ok := s.getAgentIDFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
|
||||||
}
|
|
||||||
for _, jobID := range s.getAgentJobs(agentID) {
|
|
||||||
_ = s.jobService.UpdateJobStatus(ctx, jobID, domain.JobStatusFailure)
|
|
||||||
s.untrackJobAssignment(agentID, jobID)
|
|
||||||
}
|
|
||||||
s.sessions.Delete(token)
|
|
||||||
s.agentJobs.Delete(agentID)
|
|
||||||
agent := s.getAgentWithStats(agentID)
|
|
||||||
s.agentManager.Unregister(agentID)
|
|
||||||
if s.onAgentEvent != nil {
|
|
||||||
s.onAgentEvent("agent_update", agent)
|
|
||||||
}
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) ReportHealth(ctx context.Context, _ *proto.ReportHealthRequest) (*proto.Empty, error) {
|
|
||||||
agentID, _, ok := s.getAgentIDFromContext(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "invalid session")
|
|
||||||
}
|
|
||||||
s.agentManager.UpdateHeartbeat(agentID)
|
|
||||||
return &proto.Empty{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Auth(ctx context.Context, req *proto.AuthRequest) (*proto.AuthResponse, error) {
|
|
||||||
if s.agentSecret != "" && req.AgentToken != s.agentSecret {
|
|
||||||
return nil, status.Error(codes.Unauthenticated, "invalid agent secret")
|
|
||||||
}
|
|
||||||
agentID := req.AgentId
|
|
||||||
if len(agentID) > 6 && agentID[:6] == "agent-" {
|
|
||||||
agentID = agentID[6:]
|
|
||||||
}
|
|
||||||
if agentID == "" {
|
|
||||||
agentID = generateAgentID()
|
|
||||||
}
|
|
||||||
accessToken := generateToken()
|
|
||||||
s.sessions.Store(accessToken, agentID)
|
|
||||||
return &proto.AuthResponse{Status: "ok", AgentId: agentID, AccessToken: accessToken}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) trackJobAssignment(agentID, jobID string) {
|
|
||||||
jobSetInterface, _ := s.agentJobs.LoadOrStore(agentID, &sync.Map{})
|
|
||||||
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
|
||||||
jobSet.Store(jobID, true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) untrackJobAssignment(agentID, jobID string) {
|
|
||||||
if jobSetInterface, ok := s.agentJobs.Load(agentID); ok {
|
|
||||||
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
|
||||||
jobSet.Delete(jobID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) isJobAssigned(agentID, jobID string) bool {
|
|
||||||
if jobSetInterface, ok := s.agentJobs.Load(agentID); ok {
|
|
||||||
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
|
||||||
_, found := jobSet.Load(jobID)
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getAgentJobs(agentID string) []string {
|
|
||||||
jobs := []string{}
|
|
||||||
if jobSetInterface, ok := s.agentJobs.Load(agentID); ok {
|
|
||||||
if jobSet, ok := jobSetInterface.(*sync.Map); ok {
|
|
||||||
jobSet.Range(func(key, _ any) bool {
|
|
||||||
if jobID, ok := key.(string); ok {
|
|
||||||
jobs = append(jobs, jobID)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jobs
|
|
||||||
}
|
|
||||||
@@ -1,240 +0,0 @@
|
|||||||
package runtime
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
|
||||||
"stream.api/internal/video/runtime/domain"
|
|
||||||
"stream.api/internal/video/runtime/services"
|
|
||||||
"stream.api/pkg/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultMQTTBrokerURL = "tcp://broker.mqtt-dashboard.com:1883"
|
|
||||||
defaultMQTTPrefix = "picpic"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mqttPublisher struct {
|
|
||||||
client mqtt.Client
|
|
||||||
jobService *services.JobService
|
|
||||||
agentRT interface {
|
|
||||||
ListAgentsWithStats() []*services.AgentWithStats
|
|
||||||
}
|
|
||||||
logger logger.Logger
|
|
||||||
prefix string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMQTTPublisher(jobService *services.JobService, agentRT interface {
|
|
||||||
ListAgentsWithStats() []*services.AgentWithStats
|
|
||||||
}, appLogger logger.Logger) (*mqttPublisher, error) {
|
|
||||||
opts := mqtt.NewClientOptions()
|
|
||||||
opts.AddBroker(defaultMQTTBrokerURL)
|
|
||||||
opts.SetClientID(fmt.Sprintf("stream-api-%d", time.Now().UnixNano()))
|
|
||||||
opts.SetKeepAlive(60 * time.Second)
|
|
||||||
opts.SetPingTimeout(10 * time.Second)
|
|
||||||
opts.SetAutoReconnect(true)
|
|
||||||
|
|
||||||
client := mqtt.NewClient(opts)
|
|
||||||
token := client.Connect()
|
|
||||||
if token.Wait() && token.Error() != nil {
|
|
||||||
return nil, token.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &mqttPublisher{
|
|
||||||
client: client,
|
|
||||||
jobService: jobService,
|
|
||||||
agentRT: agentRT,
|
|
||||||
logger: appLogger,
|
|
||||||
prefix: defaultMQTTPrefix,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mqttPublisher) start(ctx context.Context) {
|
|
||||||
if p == nil || p.jobService == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go p.consumeLogs(ctx)
|
|
||||||
go p.consumeJobUpdates(ctx)
|
|
||||||
go p.consumeResources(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mqttPublisher) consumeLogs(ctx context.Context) {
|
|
||||||
ch, err := p.jobService.SubscribeJobLogs(ctx, "")
|
|
||||||
if err != nil {
|
|
||||||
p.logger.Error("Failed to subscribe job logs for MQTT", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case entry, ok := <-ch:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload, _ := json.Marshal(entry)
|
|
||||||
p.publish(fmt.Sprintf("%s/logs/%s", p.prefix, entry.JobID), payload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mqttPublisher) consumeJobUpdates(ctx context.Context) {
|
|
||||||
ch, err := p.jobService.SubscribeJobUpdates(ctx)
|
|
||||||
if err != nil {
|
|
||||||
p.logger.Error("Failed to subscribe job updates for MQTT", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case msg, ok := <-ch:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var inner map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(msg), &inner); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
jobID, _ := inner["job_id"].(string)
|
|
||||||
eventPayload, _ := json.Marshal(map[string]any{
|
|
||||||
"type": "job_update",
|
|
||||||
"payload": inner,
|
|
||||||
})
|
|
||||||
if jobID != "" {
|
|
||||||
p.publish(fmt.Sprintf("%s/job/%s", p.prefix, jobID), eventPayload)
|
|
||||||
}
|
|
||||||
p.publish(fmt.Sprintf("%s/events", p.prefix), eventPayload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mqttPublisher) consumeResources(ctx context.Context) {
|
|
||||||
ch, err := p.jobService.SubscribeSystemResources(ctx)
|
|
||||||
if err != nil {
|
|
||||||
p.logger.Error("Failed to subscribe resources for MQTT", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case entry, ok := <-ch:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resourcePayload, _ := json.Marshal(map[string]any{
|
|
||||||
"type": "resource_update",
|
|
||||||
"payload": entry,
|
|
||||||
})
|
|
||||||
p.publish(fmt.Sprintf("%s/events", p.prefix), resourcePayload)
|
|
||||||
if p.agentRT != nil {
|
|
||||||
for _, agent := range p.agentRT.ListAgentsWithStats() {
|
|
||||||
if agent == nil || agent.Agent == nil || agent.ID != entry.AgentID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
agentPayload, _ := json.Marshal(map[string]any{
|
|
||||||
"type": "agent_update",
|
|
||||||
"payload": mapAgentPayload(agent),
|
|
||||||
})
|
|
||||||
p.publish(fmt.Sprintf("%s/events", p.prefix), agentPayload)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *mqttPublisher) publish(topic string, payload []byte) {
|
|
||||||
if p == nil || p.client == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token := p.client.Publish(topic, 0, false, payload)
|
|
||||||
token.WaitTimeout(5 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapAgentPayload(agent *services.AgentWithStats) map[string]any {
|
|
||||||
if agent == nil || agent.Agent == nil {
|
|
||||||
return map[string]any{}
|
|
||||||
}
|
|
||||||
return map[string]any{
|
|
||||||
"id": agent.ID,
|
|
||||||
"name": agent.Name,
|
|
||||||
"platform": agent.Platform,
|
|
||||||
"backend": agent.Backend,
|
|
||||||
"version": agent.Version,
|
|
||||||
"capacity": agent.Capacity,
|
|
||||||
"status": string(agent.Status),
|
|
||||||
"cpu": agent.CPU,
|
|
||||||
"ram": agent.RAM,
|
|
||||||
"last_heartbeat": agent.LastHeartbeat,
|
|
||||||
"created_at": agent.CreatedAt,
|
|
||||||
"updated_at": agent.UpdatedAt,
|
|
||||||
"active_job_count": agent.ActiveJobCount,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func publishAgentEvent(client mqtt.Client, appLogger logger.Logger, eventType string, agent *services.AgentWithStats) {
|
|
||||||
if client == nil || agent == nil || agent.Agent == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload, err := json.Marshal(map[string]any{
|
|
||||||
"type": eventType,
|
|
||||||
"payload": mapAgentPayload(agent),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
appLogger.Error("Failed to marshal agent MQTT event", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token := client.Publish(fmt.Sprintf("%s/events", defaultMQTTPrefix), 0, false, payload)
|
|
||||||
token.WaitTimeout(5 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
func publishResourceEvent(client mqtt.Client, appLogger logger.Logger, entry domain.SystemResource) {
|
|
||||||
if client == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
payload, err := json.Marshal(map[string]any{
|
|
||||||
"type": "resource_update",
|
|
||||||
"payload": entry,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
appLogger.Error("Failed to marshal resource MQTT event", "error", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token := client.Publish(fmt.Sprintf("%s/events", defaultMQTTPrefix), 0, false, payload)
|
|
||||||
token.WaitTimeout(5 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MQTTBootstrap struct{ *mqttPublisher }
|
|
||||||
|
|
||||||
func NewMQTTBootstrap(jobService *services.JobService, agentRT interface {
|
|
||||||
ListAgentsWithStats() []*services.AgentWithStats
|
|
||||||
}, appLogger logger.Logger) (*MQTTBootstrap, error) {
|
|
||||||
publisher, err := newMQTTPublisher(jobService, agentRT, appLogger)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &MQTTBootstrap{mqttPublisher: publisher}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *MQTTBootstrap) Start(ctx context.Context) {
|
|
||||||
if b == nil || b.mqttPublisher == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.mqttPublisher.start(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *MQTTBootstrap) Client() mqtt.Client {
|
|
||||||
if b == nil || b.mqttPublisher == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return b.client
|
|
||||||
}
|
|
||||||
|
|
||||||
func PublishAgentMQTTEvent(client mqtt.Client, appLogger logger.Logger, eventType string, agent *services.AgentWithStats) {
|
|
||||||
publishAgentEvent(client, appLogger, eventType, agent)
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,697 +0,0 @@
|
|||||||
// Copyright 2021 Woodpecker Authors
|
|
||||||
// Copyright 2011 Drone.IO Inc.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
|
||||||
// versions:
|
|
||||||
// - protoc-gen-go-grpc v1.6.1
|
|
||||||
// - protoc v3.21.12
|
|
||||||
// source: proto/woodpecker.proto
|
|
||||||
|
|
||||||
package proto
|
|
||||||
|
|
||||||
import (
|
|
||||||
context "context"
|
|
||||||
grpc "google.golang.org/grpc"
|
|
||||||
codes "google.golang.org/grpc/codes"
|
|
||||||
status "google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This is a compile-time assertion to ensure that this generated file
|
|
||||||
// is compatible with the grpc package it is being compiled against.
|
|
||||||
// Requires gRPC-Go v1.64.0 or later.
|
|
||||||
const _ = grpc.SupportPackageIsVersion9
|
|
||||||
|
|
||||||
const (
|
|
||||||
Woodpecker_Version_FullMethodName = "/proto.Woodpecker/Version"
|
|
||||||
Woodpecker_Next_FullMethodName = "/proto.Woodpecker/Next"
|
|
||||||
Woodpecker_Init_FullMethodName = "/proto.Woodpecker/Init"
|
|
||||||
Woodpecker_Wait_FullMethodName = "/proto.Woodpecker/Wait"
|
|
||||||
Woodpecker_Done_FullMethodName = "/proto.Woodpecker/Done"
|
|
||||||
Woodpecker_Extend_FullMethodName = "/proto.Woodpecker/Extend"
|
|
||||||
Woodpecker_Update_FullMethodName = "/proto.Woodpecker/Update"
|
|
||||||
Woodpecker_Log_FullMethodName = "/proto.Woodpecker/Log"
|
|
||||||
Woodpecker_RegisterAgent_FullMethodName = "/proto.Woodpecker/RegisterAgent"
|
|
||||||
Woodpecker_UnregisterAgent_FullMethodName = "/proto.Woodpecker/UnregisterAgent"
|
|
||||||
Woodpecker_ReportHealth_FullMethodName = "/proto.Woodpecker/ReportHealth"
|
|
||||||
Woodpecker_StreamJobs_FullMethodName = "/proto.Woodpecker/StreamJobs"
|
|
||||||
Woodpecker_SubmitStatus_FullMethodName = "/proto.Woodpecker/SubmitStatus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WoodpeckerClient is the client API for Woodpecker service.
|
|
||||||
//
|
|
||||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
|
||||||
//
|
|
||||||
// Woodpecker Server Service
|
|
||||||
type WoodpeckerClient interface {
|
|
||||||
Version(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*VersionResponse, error)
|
|
||||||
Next(ctx context.Context, in *NextRequest, opts ...grpc.CallOption) (*NextResponse, error)
|
|
||||||
Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
Wait(ctx context.Context, in *WaitRequest, opts ...grpc.CallOption) (*WaitResponse, error)
|
|
||||||
Done(ctx context.Context, in *DoneRequest, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
Extend(ctx context.Context, in *ExtendRequest, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
Update(ctx context.Context, in *UpdateRequest, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
Log(ctx context.Context, in *LogRequest, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
RegisterAgent(ctx context.Context, in *RegisterAgentRequest, opts ...grpc.CallOption) (*RegisterAgentResponse, error)
|
|
||||||
UnregisterAgent(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
ReportHealth(ctx context.Context, in *ReportHealthRequest, opts ...grpc.CallOption) (*Empty, error)
|
|
||||||
// New Streaming RPCs
|
|
||||||
StreamJobs(ctx context.Context, in *StreamOptions, opts ...grpc.CallOption) (grpc.ServerStreamingClient[Workflow], error)
|
|
||||||
SubmitStatus(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[StatusUpdate, Empty], error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type woodpeckerClient struct {
|
|
||||||
cc grpc.ClientConnInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWoodpeckerClient(cc grpc.ClientConnInterface) WoodpeckerClient {
|
|
||||||
return &woodpeckerClient{cc}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Version(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*VersionResponse, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(VersionResponse)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Version_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Next(ctx context.Context, in *NextRequest, opts ...grpc.CallOption) (*NextResponse, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(NextResponse)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Next_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Init_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Wait(ctx context.Context, in *WaitRequest, opts ...grpc.CallOption) (*WaitResponse, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(WaitResponse)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Wait_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Done(ctx context.Context, in *DoneRequest, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Done_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Extend(ctx context.Context, in *ExtendRequest, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Extend_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Update(ctx context.Context, in *UpdateRequest, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Update_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) Log(ctx context.Context, in *LogRequest, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_Log_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) RegisterAgent(ctx context.Context, in *RegisterAgentRequest, opts ...grpc.CallOption) (*RegisterAgentResponse, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(RegisterAgentResponse)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_RegisterAgent_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) UnregisterAgent(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_UnregisterAgent_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) ReportHealth(ctx context.Context, in *ReportHealthRequest, opts ...grpc.CallOption) (*Empty, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(Empty)
|
|
||||||
err := c.cc.Invoke(ctx, Woodpecker_ReportHealth_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) StreamJobs(ctx context.Context, in *StreamOptions, opts ...grpc.CallOption) (grpc.ServerStreamingClient[Workflow], error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
stream, err := c.cc.NewStream(ctx, &Woodpecker_ServiceDesc.Streams[0], Woodpecker_StreamJobs_FullMethodName, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
x := &grpc.GenericClientStream[StreamOptions, Workflow]{ClientStream: stream}
|
|
||||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := x.ClientStream.CloseSend(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return x, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
|
||||||
type Woodpecker_StreamJobsClient = grpc.ServerStreamingClient[Workflow]
|
|
||||||
|
|
||||||
func (c *woodpeckerClient) SubmitStatus(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[StatusUpdate, Empty], error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
stream, err := c.cc.NewStream(ctx, &Woodpecker_ServiceDesc.Streams[1], Woodpecker_SubmitStatus_FullMethodName, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
x := &grpc.GenericClientStream[StatusUpdate, Empty]{ClientStream: stream}
|
|
||||||
return x, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
|
||||||
type Woodpecker_SubmitStatusClient = grpc.ClientStreamingClient[StatusUpdate, Empty]
|
|
||||||
|
|
||||||
// WoodpeckerServer is the server API for Woodpecker service.
|
|
||||||
// All implementations must embed UnimplementedWoodpeckerServer
|
|
||||||
// for forward compatibility.
|
|
||||||
//
|
|
||||||
// Woodpecker Server Service
|
|
||||||
type WoodpeckerServer interface {
|
|
||||||
Version(context.Context, *Empty) (*VersionResponse, error)
|
|
||||||
Next(context.Context, *NextRequest) (*NextResponse, error)
|
|
||||||
Init(context.Context, *InitRequest) (*Empty, error)
|
|
||||||
Wait(context.Context, *WaitRequest) (*WaitResponse, error)
|
|
||||||
Done(context.Context, *DoneRequest) (*Empty, error)
|
|
||||||
Extend(context.Context, *ExtendRequest) (*Empty, error)
|
|
||||||
Update(context.Context, *UpdateRequest) (*Empty, error)
|
|
||||||
Log(context.Context, *LogRequest) (*Empty, error)
|
|
||||||
RegisterAgent(context.Context, *RegisterAgentRequest) (*RegisterAgentResponse, error)
|
|
||||||
UnregisterAgent(context.Context, *Empty) (*Empty, error)
|
|
||||||
ReportHealth(context.Context, *ReportHealthRequest) (*Empty, error)
|
|
||||||
// New Streaming RPCs
|
|
||||||
StreamJobs(*StreamOptions, grpc.ServerStreamingServer[Workflow]) error
|
|
||||||
SubmitStatus(grpc.ClientStreamingServer[StatusUpdate, Empty]) error
|
|
||||||
mustEmbedUnimplementedWoodpeckerServer()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnimplementedWoodpeckerServer must be embedded to have
|
|
||||||
// forward compatible implementations.
|
|
||||||
//
|
|
||||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
|
||||||
// pointer dereference when methods are called.
|
|
||||||
type UnimplementedWoodpeckerServer struct{}
|
|
||||||
|
|
||||||
func (UnimplementedWoodpeckerServer) Version(context.Context, *Empty) (*VersionResponse, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Version not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Next(context.Context, *NextRequest) (*NextResponse, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Next not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Init(context.Context, *InitRequest) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Init not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Wait(context.Context, *WaitRequest) (*WaitResponse, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Wait not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Done(context.Context, *DoneRequest) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Done not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Extend(context.Context, *ExtendRequest) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Extend not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Update(context.Context, *UpdateRequest) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Update not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) Log(context.Context, *LogRequest) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Log not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) RegisterAgent(context.Context, *RegisterAgentRequest) (*RegisterAgentResponse, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method RegisterAgent not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) UnregisterAgent(context.Context, *Empty) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method UnregisterAgent not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) ReportHealth(context.Context, *ReportHealthRequest) (*Empty, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method ReportHealth not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) StreamJobs(*StreamOptions, grpc.ServerStreamingServer[Workflow]) error {
|
|
||||||
return status.Error(codes.Unimplemented, "method StreamJobs not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) SubmitStatus(grpc.ClientStreamingServer[StatusUpdate, Empty]) error {
|
|
||||||
return status.Error(codes.Unimplemented, "method SubmitStatus not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerServer) mustEmbedUnimplementedWoodpeckerServer() {}
|
|
||||||
func (UnimplementedWoodpeckerServer) testEmbeddedByValue() {}
|
|
||||||
|
|
||||||
// UnsafeWoodpeckerServer may be embedded to opt out of forward compatibility for this service.
|
|
||||||
// Use of this interface is not recommended, as added methods to WoodpeckerServer will
|
|
||||||
// result in compilation errors.
|
|
||||||
type UnsafeWoodpeckerServer interface {
|
|
||||||
mustEmbedUnimplementedWoodpeckerServer()
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterWoodpeckerServer(s grpc.ServiceRegistrar, srv WoodpeckerServer) {
|
|
||||||
// If the following call panics, it indicates UnimplementedWoodpeckerServer was
|
|
||||||
// embedded by pointer and is nil. This will cause panics if an
|
|
||||||
// unimplemented method is ever invoked, so we test this at initialization
|
|
||||||
// time to prevent it from happening at runtime later due to I/O.
|
|
||||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
|
||||||
t.testEmbeddedByValue()
|
|
||||||
}
|
|
||||||
s.RegisterService(&Woodpecker_ServiceDesc, srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Version_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(Empty)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Version(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Version_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Version(ctx, req.(*Empty))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Next_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(NextRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Next(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Next_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Next(ctx, req.(*NextRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(InitRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Init(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Init_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Init(ctx, req.(*InitRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Wait_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(WaitRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Wait(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Wait_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Wait(ctx, req.(*WaitRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Done_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(DoneRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Done(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Done_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Done(ctx, req.(*DoneRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Extend_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(ExtendRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Extend(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Extend_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Extend(ctx, req.(*ExtendRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Update_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(UpdateRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Update(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Update_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Update(ctx, req.(*UpdateRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_Log_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(LogRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).Log(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_Log_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).Log(ctx, req.(*LogRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_RegisterAgent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(RegisterAgentRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).RegisterAgent(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_RegisterAgent_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).RegisterAgent(ctx, req.(*RegisterAgentRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_UnregisterAgent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(Empty)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).UnregisterAgent(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_UnregisterAgent_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).UnregisterAgent(ctx, req.(*Empty))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_ReportHealth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(ReportHealthRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerServer).ReportHealth(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: Woodpecker_ReportHealth_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerServer).ReportHealth(ctx, req.(*ReportHealthRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _Woodpecker_StreamJobs_Handler(srv interface{}, stream grpc.ServerStream) error {
|
|
||||||
m := new(StreamOptions)
|
|
||||||
if err := stream.RecvMsg(m); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return srv.(WoodpeckerServer).StreamJobs(m, &grpc.GenericServerStream[StreamOptions, Workflow]{ServerStream: stream})
|
|
||||||
}
|
|
||||||
|
|
||||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
|
||||||
type Woodpecker_StreamJobsServer = grpc.ServerStreamingServer[Workflow]
|
|
||||||
|
|
||||||
func _Woodpecker_SubmitStatus_Handler(srv interface{}, stream grpc.ServerStream) error {
|
|
||||||
return srv.(WoodpeckerServer).SubmitStatus(&grpc.GenericServerStream[StatusUpdate, Empty]{ServerStream: stream})
|
|
||||||
}
|
|
||||||
|
|
||||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
|
||||||
type Woodpecker_SubmitStatusServer = grpc.ClientStreamingServer[StatusUpdate, Empty]
|
|
||||||
|
|
||||||
// Woodpecker_ServiceDesc is the grpc.ServiceDesc for Woodpecker service.
|
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
|
||||||
// and not to be introspected or modified (even as a copy)
|
|
||||||
var Woodpecker_ServiceDesc = grpc.ServiceDesc{
|
|
||||||
ServiceName: "proto.Woodpecker",
|
|
||||||
HandlerType: (*WoodpeckerServer)(nil),
|
|
||||||
Methods: []grpc.MethodDesc{
|
|
||||||
{
|
|
||||||
MethodName: "Version",
|
|
||||||
Handler: _Woodpecker_Version_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Next",
|
|
||||||
Handler: _Woodpecker_Next_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Init",
|
|
||||||
Handler: _Woodpecker_Init_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Wait",
|
|
||||||
Handler: _Woodpecker_Wait_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Done",
|
|
||||||
Handler: _Woodpecker_Done_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Extend",
|
|
||||||
Handler: _Woodpecker_Extend_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Update",
|
|
||||||
Handler: _Woodpecker_Update_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "Log",
|
|
||||||
Handler: _Woodpecker_Log_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "RegisterAgent",
|
|
||||||
Handler: _Woodpecker_RegisterAgent_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "UnregisterAgent",
|
|
||||||
Handler: _Woodpecker_UnregisterAgent_Handler,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
MethodName: "ReportHealth",
|
|
||||||
Handler: _Woodpecker_ReportHealth_Handler,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Streams: []grpc.StreamDesc{
|
|
||||||
{
|
|
||||||
StreamName: "StreamJobs",
|
|
||||||
Handler: _Woodpecker_StreamJobs_Handler,
|
|
||||||
ServerStreams: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
StreamName: "SubmitStatus",
|
|
||||||
Handler: _Woodpecker_SubmitStatus_Handler,
|
|
||||||
ClientStreams: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Metadata: "proto/woodpecker.proto",
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
WoodpeckerAuth_Auth_FullMethodName = "/proto.WoodpeckerAuth/Auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WoodpeckerAuthClient is the client API for WoodpeckerAuth service.
|
|
||||||
//
|
|
||||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
|
||||||
type WoodpeckerAuthClient interface {
|
|
||||||
Auth(ctx context.Context, in *AuthRequest, opts ...grpc.CallOption) (*AuthResponse, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type woodpeckerAuthClient struct {
|
|
||||||
cc grpc.ClientConnInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWoodpeckerAuthClient(cc grpc.ClientConnInterface) WoodpeckerAuthClient {
|
|
||||||
return &woodpeckerAuthClient{cc}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *woodpeckerAuthClient) Auth(ctx context.Context, in *AuthRequest, opts ...grpc.CallOption) (*AuthResponse, error) {
|
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
|
||||||
out := new(AuthResponse)
|
|
||||||
err := c.cc.Invoke(ctx, WoodpeckerAuth_Auth_FullMethodName, in, out, cOpts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WoodpeckerAuthServer is the server API for WoodpeckerAuth service.
|
|
||||||
// All implementations must embed UnimplementedWoodpeckerAuthServer
|
|
||||||
// for forward compatibility.
|
|
||||||
type WoodpeckerAuthServer interface {
|
|
||||||
Auth(context.Context, *AuthRequest) (*AuthResponse, error)
|
|
||||||
mustEmbedUnimplementedWoodpeckerAuthServer()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnimplementedWoodpeckerAuthServer must be embedded to have
|
|
||||||
// forward compatible implementations.
|
|
||||||
//
|
|
||||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
|
||||||
// pointer dereference when methods are called.
|
|
||||||
type UnimplementedWoodpeckerAuthServer struct{}
|
|
||||||
|
|
||||||
func (UnimplementedWoodpeckerAuthServer) Auth(context.Context, *AuthRequest) (*AuthResponse, error) {
|
|
||||||
return nil, status.Error(codes.Unimplemented, "method Auth not implemented")
|
|
||||||
}
|
|
||||||
func (UnimplementedWoodpeckerAuthServer) mustEmbedUnimplementedWoodpeckerAuthServer() {}
|
|
||||||
func (UnimplementedWoodpeckerAuthServer) testEmbeddedByValue() {}
|
|
||||||
|
|
||||||
// UnsafeWoodpeckerAuthServer may be embedded to opt out of forward compatibility for this service.
|
|
||||||
// Use of this interface is not recommended, as added methods to WoodpeckerAuthServer will
|
|
||||||
// result in compilation errors.
|
|
||||||
type UnsafeWoodpeckerAuthServer interface {
|
|
||||||
mustEmbedUnimplementedWoodpeckerAuthServer()
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterWoodpeckerAuthServer(s grpc.ServiceRegistrar, srv WoodpeckerAuthServer) {
|
|
||||||
// If the following call panics, it indicates UnimplementedWoodpeckerAuthServer was
|
|
||||||
// embedded by pointer and is nil. This will cause panics if an
|
|
||||||
// unimplemented method is ever invoked, so we test this at initialization
|
|
||||||
// time to prevent it from happening at runtime later due to I/O.
|
|
||||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
|
||||||
t.testEmbeddedByValue()
|
|
||||||
}
|
|
||||||
s.RegisterService(&WoodpeckerAuth_ServiceDesc, srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _WoodpeckerAuth_Auth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
|
||||||
in := new(AuthRequest)
|
|
||||||
if err := dec(in); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if interceptor == nil {
|
|
||||||
return srv.(WoodpeckerAuthServer).Auth(ctx, in)
|
|
||||||
}
|
|
||||||
info := &grpc.UnaryServerInfo{
|
|
||||||
Server: srv,
|
|
||||||
FullMethod: WoodpeckerAuth_Auth_FullMethodName,
|
|
||||||
}
|
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
|
||||||
return srv.(WoodpeckerAuthServer).Auth(ctx, req.(*AuthRequest))
|
|
||||||
}
|
|
||||||
return interceptor(ctx, in, info, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WoodpeckerAuth_ServiceDesc is the grpc.ServiceDesc for WoodpeckerAuth service.
|
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
|
||||||
// and not to be introspected or modified (even as a copy)
|
|
||||||
var WoodpeckerAuth_ServiceDesc = grpc.ServiceDesc{
|
|
||||||
ServiceName: "proto.WoodpeckerAuth",
|
|
||||||
HandlerType: (*WoodpeckerAuthServer)(nil),
|
|
||||||
Methods: []grpc.MethodDesc{
|
|
||||||
{
|
|
||||||
MethodName: "Auth",
|
|
||||||
Handler: _WoodpeckerAuth_Auth_Handler,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Streams: []grpc.StreamDesc{},
|
|
||||||
Metadata: "proto/woodpecker.proto",
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
package services
|
|
||||||
|
|
||||||
import "stream.api/internal/video/runtime/domain"
|
|
||||||
|
|
||||||
type AgentWithStats struct {
|
|
||||||
*domain.Agent
|
|
||||||
ActiveJobCount int64 `json:"active_job_count"`
|
|
||||||
}
|
|
||||||
119
pkg/token/jwt.go
119
pkg/token/jwt.go
@@ -1,119 +0,0 @@
|
|||||||
package token
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// jwtClaims is an internal struct to satisfy jwt.Claims interface
|
|
||||||
type jwtClaims struct {
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
TokenID string `json:"token_id"`
|
|
||||||
jwt.RegisteredClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwt.go implements the Provider interface using JWT
|
|
||||||
|
|
||||||
type jwtProvider struct {
|
|
||||||
secret string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewJWTProvider creates a new instance of JWT provider
|
|
||||||
func NewJWTProvider(secret string) Provider {
|
|
||||||
return &jwtProvider{
|
|
||||||
secret: secret,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateTokenPair generates new access and refresh tokens
|
|
||||||
func (p *jwtProvider) GenerateTokenPair(userID, email, role string) (*TokenPair, error) {
|
|
||||||
td := &TokenPair{}
|
|
||||||
td.AtExpires = time.Now().Add(time.Minute * 15).Unix()
|
|
||||||
td.AccessUUID = uuid.New().String()
|
|
||||||
|
|
||||||
td.RtExpires = time.Now().Add(time.Hour * 24 * 7).Unix() // Expires in 7 days
|
|
||||||
td.RefreshUUID = uuid.New().String()
|
|
||||||
|
|
||||||
// Access Token
|
|
||||||
atClaims := &jwtClaims{
|
|
||||||
UserID: userID,
|
|
||||||
Email: email,
|
|
||||||
Role: role,
|
|
||||||
TokenID: td.AccessUUID,
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Unix(td.AtExpires, 0)),
|
|
||||||
Issuer: "stream.api",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
at := jwt.NewWithClaims(jwt.SigningMethodHS256, atClaims)
|
|
||||||
var err error
|
|
||||||
td.AccessToken, err = at.SignedString([]byte(p.secret))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh Token
|
|
||||||
// Refresh token can just be a random string or a JWT.
|
|
||||||
// Common practice: JWT for stateless verification, or Opaque string for stateful.
|
|
||||||
// Here we use JWT so we can carry some metadata if needed, but we check Redis anyway.
|
|
||||||
rtClaims := jwt.MapClaims{}
|
|
||||||
rtClaims["refresh_uuid"] = td.RefreshUUID
|
|
||||||
rtClaims["user_id"] = userID
|
|
||||||
rtClaims["exp"] = td.RtExpires
|
|
||||||
rt := jwt.NewWithClaims(jwt.SigningMethodHS256, rtClaims)
|
|
||||||
td.RefreshToken, err = rt.SignedString([]byte(p.secret))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return td, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseToken parses the access token returning Claims
|
|
||||||
func (p *jwtProvider) ParseToken(tokenString string) (*Claims, error) {
|
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &jwtClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
return []byte(p.secret), nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims, ok := token.Claims.(*jwtClaims); ok && token.Valid {
|
|
||||||
return &Claims{
|
|
||||||
UserID: claims.UserID,
|
|
||||||
Email: claims.Email,
|
|
||||||
Role: claims.Role,
|
|
||||||
TokenID: claims.TokenID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("invalid token")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseMapToken parses token returning map[string]interface{} (generic)
|
|
||||||
func (p *jwtProvider) ParseMapToken(tokenString string) (map[string]interface{}, error) {
|
|
||||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
return []byte(p.secret), nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, jwt.ErrTokenInvalidClaims
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user