// Package agent implements the worker agent: it authenticates and
// registers with the server over gRPC, streams jobs, executes them in
// Docker containers, and reports status, logs, and results back.
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
grpcpkg "google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
proto "stream.api/internal/api/proto/agent/v1"
|
|
)
|
|
|
|
// Agent is a long-lived worker process that registers with the server
// over gRPC, receives work from a job stream, and executes it in Docker.
type Agent struct {
	client       proto.WoodpeckerClient     // job/stream/log RPCs; calls attach the token via withToken
	authClient   proto.WoodpeckerAuthClient // token exchange (Auth)
	conn         *grpcpkg.ClientConn        // shared connection for both clients; closed in Run's deferred cleanup
	secret       string                     // shared agent token presented to Auth
	token        string                     // access token returned by Auth; sent as "token" metadata
	capacity     int                        // maximum number of concurrently executing jobs
	agentID      string                     // server-assigned ID, persisted across restarts (loadAgentID/saveAgentID)
	docker       *DockerExecutor            // container runtime used by executeJob
	semaphore    chan struct{}              // counting semaphore bounding concurrent jobs to capacity
	wg           sync.WaitGroup             // tracks in-flight job goroutines so Run can drain a session
	activeJobs   sync.Map                   // job ID -> context.CancelFunc, consumed by CancelJob
	prevCPUTotal uint64                     // previous /proc/stat totals used for CPU usage deltas
	prevCPUIdle  uint64                     // previous idle+iowait jiffies from /proc/stat
}
|
|
|
|
// JobPayload is the JSON document carried in a Workflow's Payload field.
type JobPayload struct {
	Image       string            `json:"image"`       // container image to run
	Commands    []string          `json:"commands"`    // commands executed inside the container
	Environment map[string]string `json:"environment"` // extra environment variables for the container
	Action      string            `json:"action"`      // optional system command ("restart"/"update") instead of a container job
}
|
|
|
|
func New(serverAddr, secret string, capacity int) (*Agent, error) {
|
|
conn, err := grpcpkg.NewClient(serverAddr, grpcpkg.WithTransportCredentials(insecure.NewCredentials()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to server: %w", err)
|
|
}
|
|
|
|
docker, err := NewDockerExecutor()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to initialize docker executor: %w", err)
|
|
}
|
|
|
|
return &Agent{
|
|
client: proto.NewWoodpeckerClient(conn),
|
|
authClient: proto.NewWoodpeckerAuthClient(conn),
|
|
conn: conn,
|
|
secret: secret,
|
|
capacity: capacity,
|
|
docker: docker,
|
|
semaphore: make(chan struct{}, capacity),
|
|
}, nil
|
|
}
|
|
|
|
// Run is the agent main loop: it (re-)registers with the server, starts
// the per-session background routines, and consumes the job stream until
// the session drops, then waits 5s and reconnects. It returns only when
// ctx is cancelled or registration fails permanently.
func (a *Agent) Run(ctx context.Context) error {
	// Best-effort unregister and connection teardown on the way out; a
	// fresh context is used because ctx is typically already cancelled here.
	defer func() {
		unregisterCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		_ = a.unregister(unregisterCtx)
		_ = a.conn.Close()
	}()

	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		if err := a.registerWithRetry(ctx); err != nil {
			log.Printf("Registration failed hard: %v", err)
			return err
		}

		log.Printf("Agent started/reconnected with ID: %s, Capacity: %d", a.agentID, a.capacity)

		// sessionCtx scopes the background routines and in-flight jobs to
		// this connection; it is cancelled as soon as the stream ends.
		sessionCtx, sessionCancel := context.WithCancel(ctx)
		a.startBackgroundRoutines(sessionCtx)
		err := a.streamJobs(sessionCtx)
		sessionCancel()
		a.wg.Wait() // drain job goroutines before re-registering

		if ctx.Err() != nil {
			return ctx.Err()
		}

		log.Printf("Session ended: %v. Re-registering in 5s...", err)
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(5 * time.Second):
		}
	}
}
|
|
|
|
func (a *Agent) registerWithRetry(ctx context.Context) error {
|
|
backoff := 1 * time.Second
|
|
maxBackoff := 30 * time.Second
|
|
|
|
for {
|
|
if err := a.register(ctx); err != nil {
|
|
log.Printf("Registration failed: %v. Retrying in %v...", err, backoff)
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(backoff):
|
|
backoff *= 2
|
|
if backoff > maxBackoff {
|
|
backoff = maxBackoff
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// startBackgroundRoutines launches the per-session helper goroutines.
// Both exit when ctx (the session context created in Run) is cancelled.
func (a *Agent) startBackgroundRoutines(ctx context.Context) {
	go a.cancelListener(ctx)
	go a.submitStatusLoop(ctx)
}
|
|
|
|
// cancelListener is currently a no-op placeholder: job cancellations are
// delivered in-band on the job stream (see the workflow.Cancel handling
// in streamJobs) rather than on a dedicated listener.
func (a *Agent) cancelListener(context.Context) {}
|
|
|
|
// register authenticates with the server and registers this agent's
// capabilities. It reuses a persisted agent ID when available (unless
// FORCE_NEW_ID=true), adopts whatever ID the server returns, persists it
// if it changed, and stores the access token used by all later RPCs.
func (a *Agent) register(ctx context.Context) error {
	var savedID string
	if os.Getenv("FORCE_NEW_ID") != "true" {
		var err error
		savedID, err = a.loadAgentID()
		if err == nil && savedID != "" {
			log.Printf("Loaded persisted Agent ID: %s", savedID)
			a.agentID = savedID
		}
	} else {
		log.Println("Forcing new Agent ID due to FORCE_NEW_ID=true")
	}

	// Exchange the shared secret (plus any known ID and the host
	// fingerprint) for an access token; the server may assign a new ID.
	authResp, err := a.authClient.Auth(ctx, &proto.AuthRequest{
		AgentToken: a.secret,
		AgentId:    a.agentID,
		Hostname:   a.getHostFingerprint(),
	})
	if err != nil {
		return fmt.Errorf("auth failed: %w", err)
	}
	a.agentID = authResp.AgentId

	// Persist the ID only when it differs from what was on disk; a save
	// failure is logged but does not abort registration.
	if a.agentID != savedID {
		if err := a.saveAgentID(a.agentID); err != nil {
			log.Printf("Failed to save agent ID: %v", err)
		} else {
			log.Printf("Persisted Agent ID: %s", a.agentID)
		}
	}

	a.token = authResp.AccessToken
	mdCtx := metadata.AppendToOutgoingContext(ctx, "token", a.token)
	hostname := a.getHostFingerprint()

	// Advertise platform/backend/capacity; the hostname label is also used
	// by streamJobs as a filter.
	_, err = a.client.RegisterAgent(mdCtx, &proto.RegisterAgentRequest{
		Info: &proto.AgentInfo{
			Platform: "linux/amd64",
			Backend:  "ffmpeg",
			Version:  "stream-api-agent-v1",
			Capacity: int32(a.capacity),
			CustomLabels: map[string]string{
				"hostname": hostname,
			},
		},
	})
	if err != nil {
		return fmt.Errorf("registration failed: %w", err)
	}

	return nil
}
|
|
|
|
// withToken returns ctx with the current access token (obtained in
// register) attached as outgoing gRPC metadata under the "token" key.
func (a *Agent) withToken(ctx context.Context) context.Context {
	return metadata.AppendToOutgoingContext(ctx, "token", a.token)
}
|
|
|
|
func (a *Agent) streamJobs(ctx context.Context) error {
|
|
mdCtx := a.withToken(ctx)
|
|
hostname, err := os.Hostname()
|
|
if err != nil {
|
|
hostname = "unknown-agent"
|
|
}
|
|
|
|
stream, err := a.client.StreamJobs(mdCtx, &proto.StreamOptions{
|
|
Filter: &proto.Filter{
|
|
Labels: map[string]string{
|
|
"hostname": hostname,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start job stream: %w", err)
|
|
}
|
|
|
|
log.Println("Connected to job stream")
|
|
|
|
for {
|
|
select {
|
|
case a.semaphore <- struct{}{}:
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-stream.Context().Done():
|
|
return stream.Context().Err()
|
|
}
|
|
|
|
workflow, err := stream.Recv()
|
|
if err != nil {
|
|
<-a.semaphore
|
|
return fmt.Errorf("stream closed or error: %w", err)
|
|
}
|
|
|
|
if workflow.Cancel {
|
|
<-a.semaphore
|
|
log.Printf("Received cancellation signal for job %s", workflow.Id)
|
|
if found := a.CancelJob(workflow.Id); found {
|
|
log.Printf("Job %s cancellation triggered", workflow.Id)
|
|
} else {
|
|
log.Printf("Job %s not found in active jobs", workflow.Id)
|
|
}
|
|
continue
|
|
}
|
|
|
|
log.Printf("Received job from stream: %s (active: %d/%d)", workflow.Id, len(a.semaphore), a.capacity)
|
|
a.wg.Add(1)
|
|
go func(wf *proto.Workflow) {
|
|
defer a.wg.Done()
|
|
defer func() { <-a.semaphore }()
|
|
a.executeJob(ctx, wf)
|
|
}(workflow)
|
|
}
|
|
}
|
|
|
|
func (a *Agent) submitStatusLoop(ctx context.Context) {
|
|
for {
|
|
if err := a.runStatusStream(ctx); err != nil {
|
|
log.Printf("Status stream error: %v. Retrying in 5s...", err)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(5 * time.Second):
|
|
continue
|
|
}
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *Agent) runStatusStream(ctx context.Context) error {
|
|
mdCtx := a.withToken(ctx)
|
|
stream, err := a.client.SubmitStatus(mdCtx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
_, _ = stream.CloseAndRecv()
|
|
return ctx.Err()
|
|
case <-ticker.C:
|
|
cpu, ram := a.collectSystemResources()
|
|
data := fmt.Sprintf(`{"cpu": %.2f, "ram": %.2f}`, cpu, ram)
|
|
if err := stream.Send(&proto.StatusUpdate{Type: 5, Time: time.Now().Unix(), Data: []byte(data)}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *Agent) collectSystemResources() (float64, float64) {
|
|
var memTotal, memAvailable uint64
|
|
|
|
data, err := os.ReadFile("/proc/meminfo")
|
|
if err == nil {
|
|
lines := strings.Split(string(data), "\n")
|
|
for _, line := range lines {
|
|
fields := strings.Fields(line)
|
|
if len(fields) < 2 {
|
|
continue
|
|
}
|
|
switch fields[0] {
|
|
case "MemTotal:":
|
|
memTotal, _ = strconv.ParseUint(fields[1], 10, 64)
|
|
case "MemAvailable:":
|
|
memAvailable, _ = strconv.ParseUint(fields[1], 10, 64)
|
|
}
|
|
}
|
|
}
|
|
|
|
usedRAM := 0.0
|
|
if memTotal > 0 {
|
|
usedRAM = float64(memTotal-memAvailable) / 1024.0
|
|
}
|
|
|
|
var cpuUsage float64
|
|
data, err = os.ReadFile("/proc/stat")
|
|
if err == nil {
|
|
lines := strings.Split(string(data), "\n")
|
|
for _, line := range lines {
|
|
if !strings.HasPrefix(line, "cpu ") {
|
|
continue
|
|
}
|
|
fields := strings.Fields(line)
|
|
if len(fields) < 5 {
|
|
break
|
|
}
|
|
|
|
var user, nice, system, idle, iowait, irq, softirq, steal uint64
|
|
user, _ = strconv.ParseUint(fields[1], 10, 64)
|
|
nice, _ = strconv.ParseUint(fields[2], 10, 64)
|
|
system, _ = strconv.ParseUint(fields[3], 10, 64)
|
|
idle, _ = strconv.ParseUint(fields[4], 10, 64)
|
|
if len(fields) > 5 {
|
|
iowait, _ = strconv.ParseUint(fields[5], 10, 64)
|
|
}
|
|
if len(fields) > 6 {
|
|
irq, _ = strconv.ParseUint(fields[6], 10, 64)
|
|
}
|
|
if len(fields) > 7 {
|
|
softirq, _ = strconv.ParseUint(fields[7], 10, 64)
|
|
}
|
|
if len(fields) > 8 {
|
|
steal, _ = strconv.ParseUint(fields[8], 10, 64)
|
|
}
|
|
|
|
currentIdle := idle + iowait
|
|
currentNonIdle := user + nice + system + irq + softirq + steal
|
|
currentTotal := currentIdle + currentNonIdle
|
|
totalDiff := currentTotal - a.prevCPUTotal
|
|
idleDiff := currentIdle - a.prevCPUIdle
|
|
|
|
if totalDiff > 0 && a.prevCPUTotal > 0 {
|
|
cpuUsage = float64(totalDiff-idleDiff) / float64(totalDiff) * 100.0
|
|
}
|
|
|
|
a.prevCPUTotal = currentTotal
|
|
a.prevCPUIdle = currentIdle
|
|
break
|
|
}
|
|
}
|
|
|
|
return cpuUsage, usedRAM
|
|
}
|
|
|
|
// executeJob runs a single workflow end-to-end: it parses the JSON
// payload, handles system actions ("restart"/"update") inline, and
// otherwise runs the payload's container via Docker, forwarding output
// lines (and parsed progress values) to the server, then reports Done.
// A per-job cancel function is published in a.activeJobs so CancelJob
// can abort the job from the stream's cancellation path.
func (a *Agent) executeJob(ctx context.Context, workflow *proto.Workflow) {
	log.Printf("Executing job %s", workflow.Id)

	jobCtx, jobCancel := context.WithCancel(ctx)
	defer jobCancel()

	// Optional per-job deadline; when present, the timeout's cancel
	// function (not the plain one above) is what gets stored in activeJobs.
	if workflow.Timeout > 0 {
		timeoutDuration := time.Duration(workflow.Timeout) * time.Second
		log.Printf("Job %s has timeout of %v", workflow.Id, timeoutDuration)
		jobCtx, jobCancel = context.WithTimeout(jobCtx, timeoutDuration)
		defer jobCancel()
	}

	a.activeJobs.Store(workflow.Id, jobCancel)
	defer a.activeJobs.Delete(workflow.Id)

	var payload JobPayload
	if err := json.Unmarshal(workflow.Payload, &payload); err != nil {
		log.Printf("Failed to parse payload for job %s: %v", workflow.Id, err)
		a.reportDone(ctx, workflow.Id, fmt.Sprintf("invalid payload: %v", err))
		return
	}

	// System commands: acknowledge first, then act — restart/update may
	// terminate the whole process via os.Exit.
	if payload.Action != "" {
		log.Printf("Received system command: %s", payload.Action)
		a.reportDone(ctx, workflow.Id, "")

		switch payload.Action {
		case "restart":
			log.Println("Restarting agent...")
			os.Exit(0)
		case "update":
			log.Println("Updating agent...")
			imageName := os.Getenv("AGENT_IMAGE")
			if imageName == "" {
				imageName = "stream-api-agent:latest"
			}

			// Background context: the self-update must not be cut short
			// by session or job cancellation.
			if err := a.docker.SelfUpdate(context.Background(), imageName, a.agentID); err != nil {
				log.Printf("Update failed: %v", err)
			} else {
				os.Exit(0)
			}
		}
		return
	}

	mdCtx := a.withToken(ctx)
	if _, err := a.client.Init(mdCtx, &proto.InitRequest{Id: workflow.Id}); err != nil {
		log.Printf("Failed to init job %s: %v", workflow.Id, err)
		return
	}

	log.Printf("Running container with image: %s", payload.Image)
	done := make(chan error, 1) // buffered so the runner goroutine never blocks on send
	go a.extendLoop(jobCtx, workflow.Id)
	go func() {
		done <- a.docker.Run(jobCtx, payload.Image, payload.Commands, payload.Environment, func(line string) {
			// Every output line becomes a log entry (Type 1); lines that
			// parse as progress additionally emit a progress entry (Type 4).
			progress := -1.0
			if val, ok := parseProgress(line); ok {
				progress = val
			}

			entries := []*proto.LogEntry{{
				StepUuid: workflow.Id,
				Data:     []byte(line),
				Time:     time.Now().Unix(),
				Type:     1,
			}}

			if progress >= 0 {
				entries = append(entries, &proto.LogEntry{
					StepUuid: workflow.Id,
					Time:     time.Now().Unix(),
					Type:     4,
					Data:     []byte(fmt.Sprintf("%f", progress)),
				})
			}

			if _, err := a.client.Log(mdCtx, &proto.LogRequest{LogEntries: entries}); err != nil {
				log.Printf("Failed to send log for job %s: %v", workflow.Id, err)
			}
		})
	}()

	// Wait for completion, timeout, or cancellation — whichever first.
	var err error
	select {
	case err = <-done:
	case <-jobCtx.Done():
		if jobCtx.Err() == context.DeadlineExceeded {
			err = fmt.Errorf("job timeout exceeded")
			log.Printf("Job %s timed out", workflow.Id)
		} else {
			err = fmt.Errorf("job cancelled")
			log.Printf("Job %s was cancelled", workflow.Id)
		}
	}

	if err != nil {
		log.Printf("Job %s failed: %v", workflow.Id, err)
		a.reportDone(ctx, workflow.Id, err.Error())
	} else {
		log.Printf("Job %s succeeded", workflow.Id)
		a.reportDone(ctx, workflow.Id, "")
	}
}
|
|
|
|
func (a *Agent) CancelJob(jobID string) bool {
|
|
if cancelFunc, ok := a.activeJobs.Load(jobID); ok {
|
|
log.Printf("Cancelling job %s", jobID)
|
|
cancelFunc.(context.CancelFunc)()
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (a *Agent) extendLoop(ctx context.Context, jobID string) {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
mdCtx := a.withToken(ctx)
|
|
if _, err := a.client.Extend(mdCtx, &proto.ExtendRequest{Id: jobID}); err != nil {
|
|
log.Printf("Failed to extend lease for job %s: %v", jobID, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *Agent) reportDone(_ context.Context, id string, errStr string) {
|
|
state := &proto.WorkflowState{Finished: time.Now().Unix()}
|
|
if errStr != "" {
|
|
state.Error = errStr
|
|
}
|
|
|
|
reportCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
mdCtx := a.withToken(reportCtx)
|
|
_, err := a.client.Done(mdCtx, &proto.DoneRequest{Id: id, State: state})
|
|
if err != nil {
|
|
log.Printf("Failed to report Done for job %s: %v", id, err)
|
|
}
|
|
}
|
|
|
|
func (a *Agent) unregister(ctx context.Context) error {
|
|
if a.token == "" {
|
|
return nil
|
|
}
|
|
|
|
mdCtx := a.withToken(ctx)
|
|
_, err := a.client.UnregisterAgent(mdCtx, &proto.Empty{})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to unregister agent: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
const (
	// AgentIDFile is where the persisted agent identity lives: a JSON
	// AgentIdentity document, or (legacy) a bare ID string.
	AgentIDFile = "/data/agent_id"
	// HostnameFile, when readable, supplies the host fingerprint used by
	// getHostFingerprint in preference to os.Hostname().
	HostnameFile = "/host_hostname"
)
|
|
|
|
// AgentIdentity is the JSON document persisted at AgentIDFile, binding
// the server-assigned agent ID to the host fingerprint it was issued on.
type AgentIdentity struct {
	ID          string `json:"id"`
	Fingerprint string `json:"fingerprint"` // fingerprint at save time; a mismatch invalidates the saved ID
}
|
|
|
|
func (a *Agent) getHostFingerprint() string {
|
|
data, err := os.ReadFile(HostnameFile)
|
|
if err == nil {
|
|
return strings.TrimSpace(string(data))
|
|
}
|
|
hostname, _ := os.Hostname()
|
|
return hostname
|
|
}
|
|
|
|
func (a *Agent) loadAgentID() (string, error) {
|
|
data, err := os.ReadFile(AgentIDFile)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var identity AgentIdentity
|
|
if err := json.Unmarshal(data, &identity); err == nil {
|
|
currentFP := a.getHostFingerprint()
|
|
if identity.Fingerprint != "" && identity.Fingerprint != currentFP {
|
|
log.Printf("Environment changed (Hostname mismatch: saved=%s, current=%s). Resetting Agent ID.", identity.Fingerprint, currentFP)
|
|
return "", fmt.Errorf("environment changed")
|
|
}
|
|
return identity.ID, nil
|
|
}
|
|
|
|
id := strings.TrimSpace(string(data))
|
|
if id == "" {
|
|
return "", fmt.Errorf("empty ID")
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (a *Agent) saveAgentID(id string) error {
|
|
if err := os.MkdirAll("/data", 0755); err != nil {
|
|
return err
|
|
}
|
|
|
|
identity := AgentIdentity{ID: id, Fingerprint: a.getHostFingerprint()}
|
|
data, err := json.Marshal(identity)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return os.WriteFile(AgentIDFile, data, 0644)
|
|
}
|