Skip to content

Commit

Permalink
runway: return custom error
Browse files Browse the repository at this point in the history
  • Loading branch information
igolaizola committed Jul 18, 2024
1 parent c443af4 commit 7c2baee
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 22 deletions.
15 changes: 11 additions & 4 deletions pkg/runway/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,26 @@ type Client struct {
token string
expiration time.Time
teamID int
folder string
}

type Config struct {
Token string
Wait time.Duration
Debug bool
Proxy string
Token string
Wait time.Duration
Debug bool
Proxy string
Folder string
}

func New(cfg *Config) (*Client, error) {
wait := cfg.Wait
if wait == 0 {
wait = 1 * time.Second
}
folder := cfg.Folder
if folder == "" {
folder = "Generative Video"
}
// Parse the JWT
parser := jwt.Parser{}
t, _, err := parser.ParseUnverified(cfg.Token, jwt.MapClaims{})
Expand All @@ -65,6 +71,7 @@ func New(cfg *Config) (*Client, error) {
debug: cfg.Debug,
token: cfg.Token,
expiration: expiration,
folder: folder,
}, nil
}

Expand Down
72 changes: 54 additions & 18 deletions pkg/runway/runway.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package runway
import (
"context"
"crypto/md5"
"encoding/json"
"fmt"
"math/rand"
"os"
Expand Down Expand Up @@ -192,21 +193,32 @@ type gen3Options struct {
}

type taskResponse struct {
Task struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
TaskType string `json:"taskType"`
Options any `json:"options"`
Status string `json:"status"`
ProgressText string `json:"progressText"`
ProgressRatio string `json:"progressRatio"`
PlaceInLine int `json:"placeInLine"`
EstimatedTimeToStartSeconds float64 `json:"estimatedTimeToStartSeconds"`
Artifacts []artifact `json:"artifacts"`
SharedAsset interface{} `json:"sharedAsset"`
} `json:"task"`
Task taskData `json:"task"`
}

type taskData struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
TaskType string `json:"taskType"`
Options any `json:"options"`
Status string `json:"status"`
Error taskError `json:"error"`
ProgressText string `json:"progressText"`
ProgressRatio string `json:"progressRatio"`
PlaceInLine int `json:"placeInLine"`
EstimatedTimeToStartSeconds float64 `json:"estimatedTimeToStartSeconds"`
Artifacts []artifact `json:"artifacts"`
SharedAsset interface{} `json:"sharedAsset"`
}

type taskError struct {
ErrorMessage string `json:"errorMessage"`
Reason string `json:"reason"`
Message string `json:"message"`
ModerationCategory string `json:"moderation_category"`
TallyAsimov bool `json:"tally_asimov"`
}

type artifact struct {
Expand Down Expand Up @@ -257,6 +269,30 @@ type GenerateRequest struct {
ExploreMode bool
}

type Error struct {
data taskData
}

func (e *Error) Error() string {
return fmt.Sprintf("runway: task %s %q (%q, %q)", e.data.Status, e.data.Error.Message, e.data.Error.Reason, e.data.Error.ModerationCategory)
}

func (e *Error) Debug() string {
js, _ := json.Marshal(e.data)
return string(js)
}

func (e *Error) Temporary() bool {
switch e.data.Error.Reason {
case "SAFETY.INPUT.TEXT":
return false
case "":
return true
default:
return false
}
}

func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generation, error) {
// Load team ID
if err := c.loadTeamID(ctx); err != nil {
Expand Down Expand Up @@ -316,7 +352,7 @@ func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generatio
Height: height,
},
Name: name,
AssetGroupName: "Generative Video",
AssetGroupName: c.folder,
ExploreMode: cfg.ExploreMode,
},
AsTeamID: c.teamID,
Expand All @@ -339,7 +375,7 @@ func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generatio
EnhancePrompt: true,
Width: width,
Height: height,
AssetGroupName: "Generative Video",
AssetGroupName: c.folder,
},
AsTeamID: c.teamID,
}
Expand Down Expand Up @@ -370,7 +406,7 @@ func (c *Client) Generate(ctx context.Context, cfg *GenerateRequest) (*Generatio
case "PENDING", "RUNNING", "THROTTLED":
c.log("runway: task %s: %s", taskResp.Task.ID, taskResp.Task.ProgressRatio)
default:
return nil, fmt.Errorf("runway: task failed: %s", taskResp.Task.Status)
return nil, &Error{data: taskResp.Task}
}

select {
Expand Down

0 comments on commit 7c2baee

Please sign in to comment.