Files
stream.api/internal/workflow/agent/agent.go
2026-04-02 11:01:30 +00:00

577 lines
14 KiB
Go

package agent
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"strconv"
"strings"
"sync"
"time"
grpcpkg "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
proto "stream.api/internal/api/proto/agent/v1"
)
// Agent is a long-lived worker that authenticates against a gRPC server,
// streams jobs, executes them in Docker containers, and reports status.
type Agent struct {
	client     proto.WoodpeckerClient     // job-stream / status / log / done RPCs
	authClient proto.WoodpeckerAuthClient // token exchange (Auth)
	conn       *grpcpkg.ClientConn        // shared gRPC connection; closed in Run's deferred cleanup
	secret     string                     // shared agent token presented to Auth
	token      string                     // per-session access token, attached as "token" metadata
	capacity   int                        // maximum number of concurrently executing jobs
	agentID    string                     // server-assigned ID, persisted via saveAgentID
	docker     *DockerExecutor            // container runtime used to run job payloads
	semaphore  chan struct{}              // bounds concurrent jobs to capacity (slot per job)
	wg         sync.WaitGroup             // tracks in-flight job goroutines for the current session
	activeJobs sync.Map                   // job ID -> context.CancelFunc, used by CancelJob
	// Previous /proc/stat jiffy counters, used to compute CPU usage deltas
	// in collectSystemResources (accessed only from the status loop).
	prevCPUTotal uint64
	prevCPUIdle  uint64
}
// JobPayload is the JSON document carried in a workflow's Payload field.
// Either Action is set (system command handled in executeJob) or
// Image/Commands/Environment describe a container run.
type JobPayload struct {
	Image       string            `json:"image"`       // container image to run
	Commands    []string          `json:"commands"`    // commands executed inside the container
	Environment map[string]string `json:"environment"` // extra environment variables for the container
	Action      string            `json:"action"`      // optional system command ("restart"/"update"); bypasses the container run
}
// New dials the gRPC server (insecure transport) and prepares a Docker
// executor. capacity bounds the number of jobs executed concurrently.
// The returned Agent owns the connection; Run closes it on exit.
func New(serverAddr, secret string, capacity int) (*Agent, error) {
	conn, err := grpcpkg.NewClient(serverAddr, grpcpkg.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, fmt.Errorf("failed to connect to server: %w", err)
	}
	docker, err := NewDockerExecutor()
	if err != nil {
		// Don't leak the gRPC connection when executor setup fails;
		// nothing else will ever close it.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to initialize docker executor: %w", err)
	}
	return &Agent{
		client:     proto.NewWoodpeckerClient(conn),
		authClient: proto.NewWoodpeckerAuthClient(conn),
		conn:       conn,
		secret:     secret,
		capacity:   capacity,
		docker:     docker,
		semaphore:  make(chan struct{}, capacity),
	}, nil
}
// Run is the agent's main loop: register (with retry), stream jobs until
// the session breaks, wait for in-flight jobs, then re-register after a
// short pause. It returns only when ctx is cancelled or registration
// fails hard. On exit it best-effort unregisters and closes the
// gRPC connection.
func (a *Agent) Run(ctx context.Context) error {
	defer func() {
		// Use a fresh background context: ctx is typically already
		// cancelled by the time this cleanup runs.
		unregisterCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		_ = a.unregister(unregisterCtx)
		_ = a.conn.Close()
	}()
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}
		if err := a.registerWithRetry(ctx); err != nil {
			log.Printf("Registration failed hard: %v", err)
			return err
		}
		log.Printf("Agent started/reconnected with ID: %s, Capacity: %d", a.agentID, a.capacity)
		// Per-session context: background routines and running jobs are
		// torn down together when the job stream ends.
		sessionCtx, sessionCancel := context.WithCancel(ctx)
		a.startBackgroundRoutines(sessionCtx)
		err := a.streamJobs(sessionCtx)
		sessionCancel()
		// Wait for this session's job goroutines before reconnecting so
		// a new session never overlaps with leftover work.
		a.wg.Wait()
		if ctx.Err() != nil {
			return ctx.Err()
		}
		log.Printf("Session ended: %v. Re-registering in 5s...", err)
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(5 * time.Second):
		}
	}
}
// registerWithRetry calls register until it succeeds, sleeping between
// attempts with exponential backoff (1s, doubling, capped at 30s).
// It returns the context error as soon as ctx is cancelled.
func (a *Agent) registerWithRetry(ctx context.Context) error {
	const maxBackoff = 30 * time.Second
	delay := 1 * time.Second
	for {
		err := a.register(ctx)
		if err == nil {
			return nil
		}
		log.Printf("Registration failed: %v. Retrying in %v...", err, delay)
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(delay):
		}
		if delay *= 2; delay > maxBackoff {
			delay = maxBackoff
		}
	}
}
// startBackgroundRoutines launches the per-session helper goroutines;
// both exit when the session context is cancelled.
func (a *Agent) startBackgroundRoutines(ctx context.Context) {
	go a.cancelListener(ctx)
	go a.submitStatusLoop(ctx)
}
// cancelListener is an intentionally empty placeholder: cancellation is
// currently delivered in-band on the job stream (workflow.Cancel in
// streamJobs), so there is no separate cancel channel to listen on.
func (a *Agent) cancelListener(context.Context) {}
// register performs the two-step handshake with the server: Auth trades
// the shared secret (plus any persisted agent ID) for a server-assigned
// ID and an access token, then RegisterAgent announces platform,
// capacity, and labels. A newly assigned ID is persisted to disk so the
// agent keeps its identity across restarts. FORCE_NEW_ID=true skips
// loading the persisted ID.
func (a *Agent) register(ctx context.Context) error {
	var savedID string
	if os.Getenv("FORCE_NEW_ID") != "true" {
		var err error
		savedID, err = a.loadAgentID()
		if err == nil && savedID != "" {
			log.Printf("Loaded persisted Agent ID: %s", savedID)
			a.agentID = savedID
		}
	} else {
		log.Println("Forcing new Agent ID due to FORCE_NEW_ID=true")
	}
	// The fingerprint is needed twice (auth request and routing label);
	// compute it once instead of re-reading the hostname file.
	hostname := a.getHostFingerprint()
	authResp, err := a.authClient.Auth(ctx, &proto.AuthRequest{
		AgentToken: a.secret,
		AgentId:    a.agentID,
		Hostname:   hostname,
	})
	if err != nil {
		return fmt.Errorf("auth failed: %w", err)
	}
	a.agentID = authResp.AgentId
	if a.agentID != savedID {
		if err := a.saveAgentID(a.agentID); err != nil {
			log.Printf("Failed to save agent ID: %v", err)
		} else {
			log.Printf("Persisted Agent ID: %s", a.agentID)
		}
	}
	a.token = authResp.AccessToken
	mdCtx := metadata.AppendToOutgoingContext(ctx, "token", a.token)
	_, err = a.client.RegisterAgent(mdCtx, &proto.RegisterAgentRequest{
		Info: &proto.AgentInfo{
			Platform: "linux/amd64",
			Backend:  "ffmpeg",
			Version:  "stream-api-agent-v1",
			Capacity: int32(a.capacity),
			CustomLabels: map[string]string{
				"hostname": hostname,
			},
		},
	})
	if err != nil {
		return fmt.Errorf("registration failed: %w", err)
	}
	return nil
}
// withToken returns ctx with the current session access token attached
// as outgoing gRPC metadata under the "token" key.
func (a *Agent) withToken(ctx context.Context) context.Context {
	return metadata.AppendToOutgoingContext(ctx, "token", a.token)
}
// streamJobs opens the server job stream and dispatches received jobs.
// A semaphore slot is acquired BEFORE Recv so the agent never accepts
// more than capacity jobs; the slot is released when the job finishes
// (or immediately for cancel messages and stream errors). In-band
// messages with Cancel set abort the matching active job instead of
// starting a new one. Returns when the stream breaks or ctx is
// cancelled.
func (a *Agent) streamJobs(ctx context.Context) error {
	mdCtx := a.withToken(ctx)
	// Filter on the same fingerprint that register() advertised as the
	// "hostname" custom label. Using os.Hostname() here would not match
	// when the /host_hostname override file is present.
	hostname := a.getHostFingerprint()
	stream, err := a.client.StreamJobs(mdCtx, &proto.StreamOptions{
		Filter: &proto.Filter{
			Labels: map[string]string{
				"hostname": hostname,
			},
		},
	})
	if err != nil {
		return fmt.Errorf("failed to start job stream: %w", err)
	}
	log.Println("Connected to job stream")
	for {
		// Block until we have capacity for one more job.
		select {
		case a.semaphore <- struct{}{}:
		case <-ctx.Done():
			return ctx.Err()
		case <-stream.Context().Done():
			return stream.Context().Err()
		}
		workflow, err := stream.Recv()
		if err != nil {
			<-a.semaphore
			return fmt.Errorf("stream closed or error: %w", err)
		}
		if workflow.Cancel {
			// Cancellation message, not a new job: give the slot back.
			<-a.semaphore
			log.Printf("Received cancellation signal for job %s", workflow.Id)
			if found := a.CancelJob(workflow.Id); found {
				log.Printf("Job %s cancellation triggered", workflow.Id)
			} else {
				log.Printf("Job %s not found in active jobs", workflow.Id)
			}
			continue
		}
		log.Printf("Received job from stream: %s (active: %d/%d)", workflow.Id, len(a.semaphore), a.capacity)
		a.wg.Add(1)
		go func(wf *proto.Workflow) {
			defer a.wg.Done()
			defer func() { <-a.semaphore }()
			a.executeJob(ctx, wf)
		}(workflow)
	}
}
// submitStatusLoop keeps a status stream open for the life of ctx,
// reopening it after a 5-second pause whenever it fails.
func (a *Agent) submitStatusLoop(ctx context.Context) {
	for {
		err := a.runStatusStream(ctx)
		if err == nil {
			return
		}
		log.Printf("Status stream error: %v. Retrying in 5s...", err)
		select {
		case <-ctx.Done():
			return
		case <-time.After(5 * time.Second):
		}
	}
}
// runStatusStream pushes a CPU/RAM sample to the server every 5 seconds
// until ctx is cancelled or a send fails.
func (a *Agent) runStatusStream(ctx context.Context) error {
	stream, err := a.client.SubmitStatus(a.withToken(ctx))
	if err != nil {
		return err
	}
	const sampleInterval = 5 * time.Second
	ticker := time.NewTicker(sampleInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			_, _ = stream.CloseAndRecv()
			return ctx.Err()
		case <-ticker.C:
			cpu, ram := a.collectSystemResources()
			payload := fmt.Sprintf(`{"cpu": %.2f, "ram": %.2f}`, cpu, ram)
			update := &proto.StatusUpdate{
				Type: 5,
				Time: time.Now().Unix(),
				Data: []byte(payload),
			}
			if err := stream.Send(update); err != nil {
				return err
			}
		}
	}
}
// collectSystemResources samples the host and returns (cpuUsagePercent,
// usedRAMMiB). CPU usage is the busy share of jiffies elapsed since the
// previous call (0 on the first call); RAM is MemTotal-MemAvailable
// from /proc/meminfo in MiB. Not safe for concurrent use: it mutates
// prevCPUTotal/prevCPUIdle.
func (a *Agent) collectSystemResources() (float64, float64) {
	return a.readCPUUsage(), a.readUsedRAM()
}

// readUsedRAM returns MemTotal-MemAvailable in MiB, or 0 when
// /proc/meminfo is unreadable or MemTotal is missing/zero.
func (a *Agent) readUsedRAM() float64 {
	data, err := os.ReadFile("/proc/meminfo")
	if err != nil {
		return 0
	}
	var memTotal, memAvailable uint64
	for _, line := range strings.Split(string(data), "\n") {
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		switch fields[0] {
		case "MemTotal:":
			memTotal, _ = strconv.ParseUint(fields[1], 10, 64)
		case "MemAvailable:":
			memAvailable, _ = strconv.ParseUint(fields[1], 10, 64)
		}
	}
	if memTotal == 0 {
		return 0
	}
	return float64(memTotal-memAvailable) / 1024.0
}

// readCPUUsage returns aggregate CPU busy percent since the previous
// sample, derived from the "cpu " line of /proc/stat. It returns 0 on
// the first sample, on read/parse failure, or when the kernel counters
// went backwards (some counters, e.g. iowait, can decrease; computing
// an unsigned delta would underflow and yield garbage).
func (a *Agent) readCPUUsage() float64 {
	data, err := os.ReadFile("/proc/stat")
	if err != nil {
		return 0
	}
	for _, line := range strings.Split(string(data), "\n") {
		if !strings.HasPrefix(line, "cpu ") {
			continue
		}
		fields := strings.Fields(line)
		if len(fields) < 5 {
			return 0
		}
		// fields: "cpu" user nice system idle [iowait irq softirq steal ...]
		field := func(i int) uint64 {
			if i >= len(fields) {
				return 0
			}
			v, _ := strconv.ParseUint(fields[i], 10, 64)
			return v
		}
		currentIdle := field(4) + field(5) // idle + iowait
		currentNonIdle := field(1) + field(2) + field(3) + field(6) + field(7) + field(8)
		currentTotal := currentIdle + currentNonIdle
		prevTotal, prevIdle := a.prevCPUTotal, a.prevCPUIdle
		a.prevCPUTotal, a.prevCPUIdle = currentTotal, currentIdle
		// Skip the first sample and guard against counter regression
		// before taking unsigned differences.
		if prevTotal == 0 || currentTotal <= prevTotal || currentIdle < prevIdle {
			return 0
		}
		totalDiff := currentTotal - prevTotal
		idleDiff := currentIdle - prevIdle
		return float64(totalDiff-idleDiff) / float64(totalDiff) * 100.0
	}
	return 0
}
// executeJob runs a single workflow to completion and reports its
// result via reportDone. Payloads with Action set are system commands
// ("restart"/"update", both of which may exit the process); otherwise
// the payload describes a container run executed through the Docker
// executor. The job's cancel func is registered in activeJobs so
// CancelJob can abort it, and workflow.Timeout (seconds) adds an
// optional per-job deadline.
func (a *Agent) executeJob(ctx context.Context, workflow *proto.Workflow) {
	log.Printf("Executing job %s", workflow.Id)
	jobCtx, jobCancel := context.WithCancel(ctx)
	defer jobCancel()
	if workflow.Timeout > 0 {
		timeoutDuration := time.Duration(workflow.Timeout) * time.Second
		log.Printf("Job %s has timeout of %v", workflow.Id, timeoutDuration)
		// Re-wrap with a deadline; the earlier defer still cancels the
		// parent, and this defer cancels the timeout context.
		jobCtx, jobCancel = context.WithTimeout(jobCtx, timeoutDuration)
		defer jobCancel()
	}
	// Expose the (possibly timeout-wrapped) cancel func for CancelJob.
	a.activeJobs.Store(workflow.Id, jobCancel)
	defer a.activeJobs.Delete(workflow.Id)
	var payload JobPayload
	if err := json.Unmarshal(workflow.Payload, &payload); err != nil {
		log.Printf("Failed to parse payload for job %s: %v", workflow.Id, err)
		a.reportDone(ctx, workflow.Id, fmt.Sprintf("invalid payload: %v", err))
		return
	}
	if payload.Action != "" {
		// System-command path: acknowledge first, then act, because
		// restart/update may terminate the process.
		log.Printf("Received system command: %s", payload.Action)
		a.reportDone(ctx, workflow.Id, "")
		switch payload.Action {
		case "restart":
			// Exits cleanly; presumably a supervisor/restart policy
			// brings the agent back up — confirm deployment setup.
			log.Println("Restarting agent...")
			os.Exit(0)
		case "update":
			log.Println("Updating agent...")
			imageName := os.Getenv("AGENT_IMAGE")
			if imageName == "" {
				imageName = "stream-api-agent:latest"
			}
			if err := a.docker.SelfUpdate(context.Background(), imageName, a.agentID); err != nil {
				log.Printf("Update failed: %v", err)
			} else {
				os.Exit(0)
			}
		}
		return
	}
	mdCtx := a.withToken(ctx)
	if _, err := a.client.Init(mdCtx, &proto.InitRequest{Id: workflow.Id}); err != nil {
		// NOTE(review): no reportDone on Init failure — presumably the
		// server reclaims jobs that never initialize; verify.
		log.Printf("Failed to init job %s: %v", workflow.Id, err)
		return
	}
	log.Printf("Running container with image: %s", payload.Image)
	done := make(chan error, 1)
	// Keep the server-side lease alive while the container runs.
	go a.extendLoop(jobCtx, workflow.Id)
	go func() {
		// Stream container output line-by-line to the server. Lines that
		// parse as progress additionally emit a type-4 progress entry.
		done <- a.docker.Run(jobCtx, payload.Image, payload.Commands, payload.Environment, func(line string) {
			progress := -1.0
			if val, ok := parseProgress(line); ok {
				progress = val
			}
			entries := []*proto.LogEntry{{
				StepUuid: workflow.Id,
				Data:     []byte(line),
				Time:     time.Now().Unix(),
				Type:     1,
			}}
			if progress >= 0 {
				entries = append(entries, &proto.LogEntry{
					StepUuid: workflow.Id,
					Time:     time.Now().Unix(),
					Type:     4,
					Data:     []byte(fmt.Sprintf("%f", progress)),
				})
			}
			if _, err := a.client.Log(mdCtx, &proto.LogRequest{LogEntries: entries}); err != nil {
				log.Printf("Failed to send log for job %s: %v", workflow.Id, err)
			}
		})
	}()
	var err error
	select {
	case err = <-done:
	case <-jobCtx.Done():
		// Distinguish timeout from explicit cancellation for reporting.
		if jobCtx.Err() == context.DeadlineExceeded {
			err = fmt.Errorf("job timeout exceeded")
			log.Printf("Job %s timed out", workflow.Id)
		} else {
			err = fmt.Errorf("job cancelled")
			log.Printf("Job %s was cancelled", workflow.Id)
		}
	}
	if err != nil {
		log.Printf("Job %s failed: %v", workflow.Id, err)
		a.reportDone(ctx, workflow.Id, err.Error())
	} else {
		log.Printf("Job %s succeeded", workflow.Id)
		a.reportDone(ctx, workflow.Id, "")
	}
}
// CancelJob aborts a running job by invoking the cancel function stored
// in activeJobs. It reports whether the job was found.
func (a *Agent) CancelJob(jobID string) bool {
	entry, ok := a.activeJobs.Load(jobID)
	if !ok {
		return false
	}
	log.Printf("Cancelling job %s", jobID)
	entry.(context.CancelFunc)()
	return true
}
// extendLoop renews the server-side lease for jobID every 5 minutes
// until ctx is cancelled. A failed Extend is logged and retried on the
// next tick.
func (a *Agent) extendLoop(ctx context.Context, jobID string) {
	const leasePeriod = 5 * time.Minute
	ticker := time.NewTicker(leasePeriod)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		if _, err := a.client.Extend(a.withToken(ctx), &proto.ExtendRequest{Id: jobID}); err != nil {
			log.Printf("Failed to extend lease for job %s: %v", jobID, err)
		}
	}
}
// reportDone notifies the server that job id finished, with errStr set
// when it failed. The caller's context is deliberately ignored: a fresh
// 30-second background context ensures the completion report goes out
// even when the job's own context was cancelled or timed out.
func (a *Agent) reportDone(_ context.Context, id string, errStr string) {
	reportCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	state := &proto.WorkflowState{Finished: time.Now().Unix()}
	if errStr != "" {
		state.Error = errStr
	}
	if _, err := a.client.Done(a.withToken(reportCtx), &proto.DoneRequest{Id: id, State: state}); err != nil {
		log.Printf("Failed to report Done for job %s: %v", id, err)
	}
}
// unregister tells the server this agent is going away. It is a no-op
// when the agent never authenticated (no access token yet).
func (a *Agent) unregister(ctx context.Context) error {
	if a.token == "" {
		return nil
	}
	if _, err := a.client.UnregisterAgent(a.withToken(ctx), &proto.Empty{}); err != nil {
		return fmt.Errorf("failed to unregister agent: %w", err)
	}
	return nil
}
const (
	// AgentIDFile is where the agent persists its server-assigned
	// identity (JSON AgentIdentity, or a legacy bare ID).
	AgentIDFile = "/data/agent_id"
	// HostnameFile optionally overrides the host fingerprint, e.g. the
	// real host's hostname mounted into the agent's container.
	HostnameFile = "/host_hostname"
)
// AgentIdentity is the JSON document stored at AgentIDFile. Fingerprint
// records the host the ID was issued on so that a data volume moved to
// a different host resets the identity instead of reusing the old one.
type AgentIdentity struct {
	ID          string `json:"id"`
	Fingerprint string `json:"fingerprint"`
}
// getHostFingerprint identifies the underlying host: the trimmed
// contents of HostnameFile when readable, otherwise the OS hostname
// (best effort; its error is ignored, matching the original behavior).
func (a *Agent) getHostFingerprint() string {
	if raw, err := os.ReadFile(HostnameFile); err == nil {
		return strings.TrimSpace(string(raw))
	}
	name, _ := os.Hostname()
	return name
}
// loadAgentID reads the persisted identity from AgentIDFile. It accepts
// either the JSON AgentIdentity format or a legacy plain-text ID. A
// JSON identity whose fingerprint no longer matches this host is
// rejected so the agent re-registers with a fresh ID.
func (a *Agent) loadAgentID() (string, error) {
	raw, err := os.ReadFile(AgentIDFile)
	if err != nil {
		return "", err
	}
	var identity AgentIdentity
	if json.Unmarshal(raw, &identity) == nil {
		currentFP := a.getHostFingerprint()
		if identity.Fingerprint != "" && identity.Fingerprint != currentFP {
			log.Printf("Environment changed (Hostname mismatch: saved=%s, current=%s). Resetting Agent ID.", identity.Fingerprint, currentFP)
			return "", fmt.Errorf("environment changed")
		}
		return identity.ID, nil
	}
	// Legacy format: the file holds the bare ID as plain text.
	if id := strings.TrimSpace(string(raw)); id != "" {
		return id, nil
	}
	return "", fmt.Errorf("empty ID")
}
// saveAgentID persists the identity (paired with the current host
// fingerprint) as JSON to AgentIDFile, creating /data if needed.
func (a *Agent) saveAgentID(id string) error {
	if err := os.MkdirAll("/data", 0755); err != nil {
		return err
	}
	payload, err := json.Marshal(AgentIdentity{
		ID:          id,
		Fingerprint: a.getHostFingerprint(),
	})
	if err != nil {
		return err
	}
	return os.WriteFile(AgentIDFile, payload, 0644)
}