From 7c2baeeded0485267927a99d17834dfd2a18dd5c Mon Sep 17 00:00:00 2001 From: igolaizola <11333576+igolaizola@users.noreply.github.com> Date: Thu, 18 Jul 2024 00:33:54 +0200 Subject: [PATCH] runway: return custom error --- pkg/runway/client.go | 15 ++++++--- pkg/runway/runway.go | 72 +++++++++++++++++++++++++++++++++----------- 2 files changed, 65 insertions(+), 22 deletions(-) diff --git a/pkg/runway/client.go b/pkg/runway/client.go index 86de8c8..742b0cf 100644 --- a/pkg/runway/client.go +++ b/pkg/runway/client.go @@ -26,13 +26,15 @@ type Client struct { token string expiration time.Time teamID int + folder string } type Config struct { - Token string - Wait time.Duration - Debug bool - Proxy string + Token string + Wait time.Duration + Debug bool + Proxy string + Folder string } func New(cfg *Config) (*Client, error) { @@ -40,6 +42,10 @@ func New(cfg *Config) (*Client, error) { if wait == 0 { wait = 1 * time.Second } + folder := cfg.Folder + if folder == "" { + folder = "Generative Video" + } // Parse the JWT parser := jwt.Parser{} t, _, err := parser.ParseUnverified(cfg.Token, jwt.MapClaims{}) @@ -65,6 +71,7 @@ func New(cfg *Config) (*Client, error) { debug: cfg.Debug, token: cfg.Token, expiration: expiration, + folder: folder, }, nil } diff --git a/pkg/runway/runway.go b/pkg/runway/runway.go index b68df8d..d336e7f 100644 --- a/pkg/runway/runway.go +++ b/pkg/runway/runway.go @@ -3,6 +3,7 @@ package runway import ( "context" "crypto/md5" + "encoding/json" "fmt" "math/rand" "os" @@ -192,21 +193,32 @@ type gen3Options struct { } type taskResponse struct { - Task struct { - ID string `json:"id"` - Name string `json:"name"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - TaskType string `json:"taskType"` - Options any `json:"options"` - Status string `json:"status"` - ProgressText string `json:"progressText"` - ProgressRatio string `json:"progressRatio"` - PlaceInLine int `json:"placeInLine"` - EstimatedTimeToStartSeconds float64 `json:"estimatedTimeToStartSeconds"` - Artifacts []artifact `json:"artifacts"` - SharedAsset interface{} `json:"sharedAsset"` - } `json:"task"` + Task taskData `json:"task"` +} + +type taskData struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + TaskType string `json:"taskType"` + Options any `json:"options"` + Status string `json:"status"` + Error taskError `json:"error"` + ProgressText string `json:"progressText"` + ProgressRatio string `json:"progressRatio"` + PlaceInLine int `json:"placeInLine"` + EstimatedTimeToStartSeconds float64 `json:"estimatedTimeToStartSeconds"` + Artifacts []artifact `json:"artifacts"` + SharedAsset interface{} `json:"sharedAsset"` +} + +type taskError struct { + ErrorMessage string `json:"errorMessage"` + Reason string `json:"reason"` + Message string `json:"message"` + ModerationCategory string `json:"moderation_category"` + TallyAsimov bool `json:"tally_asimov"` } type artifact struct { @@ -257,6 +269,30 @@ type GenerateRequest struct { ExploreMode bool } +type Error struct { + data taskData +} + +func (e *Error) Error() string { + return fmt.Sprintf("runway: task %s %q (%q, %q)", e.data.Status, e.data.Error.Message, e.data.Error.Reason, e.data.Error.ModerationCategory) +} + +func (e *Error) Debug() string { + js, _ := json.Marshal(e.data) + return string(js) +} + +func (e *Error) Temporary() bool { + switch e.data.Error.Reason { + case "SAFETY.INPUT.TEXT": + return false + case "": + return true + default: + return false + } +} + func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generation, error) { // Load team ID if err := c.loadTeamID(ctx); err != nil { @@ -316,7 +352,7 @@ func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generatio Height: height, }, Name: name, - AssetGroupName: "Generative Video", + AssetGroupName: c.folder, ExploreMode: cfg.ExploreMode, }, AsTeamID: c.teamID, @@ -339,7 +375,7 @@ func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generatio EnhancePrompt: true, Width: width, Height: height, - AssetGroupName: "Generative Video", + AssetGroupName: c.folder, }, AsTeamID: c.teamID, } @@ -370,7 +406,7 @@ func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generatio case "PENDING", "RUNNING", "THROTTLED": c.log("runway: task %s: %s", taskResp.Task.ID, taskResp.Task.ProgressRatio) default: - return nil, fmt.Errorf("runway: task failed: %s", taskResp.Task.Status) + return nil, &Error{data: taskResp.Task} } select {