Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: share credential #634

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions integration/cred_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,19 @@ func TestGPTScriptCredential(t *testing.T) {
require.NoError(t, err)
require.Contains(t, out, "CREDENTIAL")
}

// TestCredentialScopes makes sure that environment variables set by credential tools and shared credential tools
// are only available to the correct tools. See scripts/credscopes.gpt for more details.
func TestCredentialScopes(t *testing.T) {
out, err := RunScript("scripts/credscopes.gpt", "--sub-tool", "oneOne")
require.NoError(t, err)
require.Contains(t, out, "good")

out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoOne")
require.NoError(t, err)
require.Contains(t, out, "good")

out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoTwo")
require.NoError(t, err)
require.Contains(t, out, "good")
}
4 changes: 4 additions & 0 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ func GPTScriptExec(args ...string) (string, error) {
out, err := cmd.CombinedOutput()
return string(out), err
}

func RunScript(script string, options ...string) (string, error) {
return GPTScriptExec(append(options, "--quiet", script)...)
}
160 changes: 160 additions & 0 deletions integration/scripts/credscopes.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# This script sets up a chain of tools in a tree structure.
# The root is oneOne, with children twoOne and twoTwo, with children threeOne, threeTwo, and threeThree, with only
# threeTwo shared between them.
# Each tool should only have access to any credentials it defines and any credentials exported/shared by its
# immediate children (but not grandchildren).
# This script checks to make sure that this is working properly.
name: oneOne
tools: twoOne, twoTwo
cred: getcred with oneOne as var and 11 as val

#!python3

import os

oneOne = os.getenv('oneOne')
twoOne = os.getenv('twoOne')
twoTwo = os.getenv('twoTwo')
threeOne = os.getenv('threeOne')
threeTwo = os.getenv('threeTwo')
threeThree = os.getenv('threeThree')

if oneOne != '11':
print('error: oneOne is not 11')
exit(1)

if twoOne != '21':
print('error: twoOne is not 21')
exit(1)

if twoTwo != '22':
print('error: twoTwo is not 22')
exit(1)

if threeOne is not None:
print('error: threeOne is not None')
exit(1)

if threeTwo is not None:
print('error: threeTwo is not None')
exit(1)

if threeThree is not None:
print('error: threeThree is not None')
exit(1)

print('good')

---
name: twoOne
tools: threeOne, threeTwo
sharecred: getcred with twoOne as var and 21 as val

#!python3

import os

oneOne = os.getenv('oneOne')
twoOne = os.getenv('twoOne')
twoTwo = os.getenv('twoTwo')
threeOne = os.getenv('threeOne')
threeTwo = os.getenv('threeTwo')
threeThree = os.getenv('threeThree')

if oneOne is not None:
print('error: oneOne is not None')
exit(1)

if twoOne is not None:
print('error: twoOne is not None')
exit(1)

if twoTwo is not None:
print('error: twoTwo is not None')
exit(1)

if threeOne != '31':
print('error: threeOne is not 31')
exit(1)

if threeTwo != '32':
print('error: threeTwo is not 32')
exit(1)

if threeThree is not None:
print('error: threeThree is not None')
exit(1)

print('good')

---
name: twoTwo
tools: threeTwo, threeThree
sharecred: getcred with twoTwo as var and 22 as val

#!python3

import os

oneOne = os.getenv('oneOne')
twoOne = os.getenv('twoOne')
twoTwo = os.getenv('twoTwo')
threeOne = os.getenv('threeOne')
threeTwo = os.getenv('threeTwo')
threeThree = os.getenv('threeThree')

if oneOne is not None:
print('error: oneOne is not None')
exit(1)

if twoOne is not None:
print('error: twoOne is not None')
exit(1)

if twoTwo is not None:
print('error: twoTwo is not None')
exit(1)

if threeOne is not None:
print('error: threeOne is not None')
exit(1)

if threeTwo != '32':
print('error: threeTwo is not 32')
exit(1)

if threeThree != '33':
print('error: threeThree is not 33')
exit(1)

print('good')

---
name: threeOne
sharecred: getcred with threeOne as var and 31 as val

---
name: threeTwo
sharecred: getcred with threeTwo as var and 32 as val

---
name: threeThree
sharecred: getcred with threeThree as var and 33 as val

---
name: getcred

#!python3

import os
import json

var = os.getenv('var')
val = os.getenv('val')

output = {
"env": {
var: val
}
}
print(json.dumps(output))
2 changes: 2 additions & 0 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
}
case "credentials", "creds", "credential", "cred":
tool.Parameters.Credentials = append(tool.Parameters.Credentials, value)
case "sharecredentials", "sharecreds", "sharecredential", "sharecred":
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value)
default:
return false, nil
}
Expand Down
41 changes: 22 additions & 19 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,13 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
return nil, err
}

if len(callCtx.Tool.Credentials) > 0 {
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
if err != nil {
return nil, err
}
if len(credTools) > 0 {
var err error
env, err = r.handleCredentials(callCtx, monitor, env)
env, err = r.handleCredentials(callCtx, monitor, env, credTools)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -552,9 +556,13 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
progress, progressClose := streamProgress(&callCtx, monitor)
defer progressClose()

if len(callCtx.Tool.Credentials) > 0 {
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
if err != nil {
return nil, err
}
if len(credTools) > 0 {
var err error
env, err = r.handleCredentials(callCtx, monitor, env)
env, err = r.handleCredentials(callCtx, monitor, env, credTools)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -828,7 +836,7 @@ func getEventContent(content string, callCtx engine.Context) string {
return content
}

func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) {
func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string, credToolRefs []types.ToolReference) ([]string, error) {
// Since credential tools (usually) prompt the user, we want to only run one at a time.
r.credMutex.Lock()
defer r.credMutex.Unlock()
Expand All @@ -845,10 +853,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}
}

for _, credToolName := range callCtx.Tool.Credentials {
toolName, credentialAlias, args, err := types.ParseCredentialArgs(credToolName, callCtx.Input)
for _, ref := range credToolRefs {
toolName, credentialAlias, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input)
if err != nil {
return nil, fmt.Errorf("failed to parse credential tool %q: %w", credToolName, err)
return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
}

credName := toolName
Expand Down Expand Up @@ -895,11 +903,6 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
// and save it in the store.
if !exists || c.IsExpired() {
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
if !ok || len(credToolRefs) != 1 {
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
}

// If the existing credential is expired, we need to provide it to the cred tool through the environment.
if exists && c.IsExpired() {
credJSON, err := json.Marshal(c)
Expand All @@ -914,22 +917,22 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
if args != nil {
inputBytes, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("failed to marshal args for tool %s: %w", credToolName, err)
return nil, fmt.Errorf("failed to marshal args for tool %s: %w", ref.Reference, err)
}
input = string(inputBytes)
}

res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, ref.ToolID, input, "", engine.CredentialToolCategory)
if err != nil {
return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err)
return nil, fmt.Errorf("failed to run credential tool %s: %w", ref.Reference, err)
}

if res.Result == nil {
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
}

if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
}
c.ToolName = credName
c.Type = credentials.CredentialTypeTool
Expand All @@ -943,7 +946,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}

// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
Expand Down
24 changes: 24 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ type Parameters struct {
Export []string `json:"export,omitempty"`
Agents []string `json:"agents,omitempty"`
Credentials []string `json:"credentials,omitempty"`
ExportCredentials []string `json:"exportCredentials,omitempty"`
InputFilters []string `json:"inputFilters,omitempty"`
ExportInputFilters []string `json:"exportInputFilters,omitempty"`
OutputFilters []string `json:"outputFilters,omitempty"`
Expand All @@ -154,6 +155,7 @@ func (p Parameters) ToolRefNames() []string {
p.ExportContext,
p.Context,
p.Credentials,
p.ExportCredentials,
p.InputFilters,
p.ExportInputFilters,
p.OutputFilters,
Expand Down Expand Up @@ -466,6 +468,11 @@ func (t ToolDef) String() string {
_, _ = fmt.Fprintf(buf, "Credential: %s\n", cred)
}
}
if len(t.Parameters.ExportCredentials) > 0 {
for _, exportCred := range t.Parameters.ExportCredentials {
_, _ = fmt.Fprintf(buf, "Share Credential: %s\n", exportCred)
}
}
if t.Parameters.Chat {
_, _ = fmt.Fprintf(buf, "Chat: true\n")
}
Expand Down Expand Up @@ -675,6 +682,23 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]
return result.List()
}

func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
result := toolRefSet{}

result.AddAll(t.GetToolRefsFromNames(t.Credentials))

toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
if err != nil {
return nil, err
}
for _, toolRef := range toolRefs {
referencedTool := prg.ToolSet[toolRef.ToolID]
result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials))
}

return result.List()
}

func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []CompletionTool) {
toolNames := map[string]struct{}{}

Expand Down