Skip to content
Merged
Prev Previous commit
Next Next commit
Some more fixes for typing
Signed-off-by: trangevi <[email protected]>
  • Loading branch information
trangevi committed Oct 26, 2025
commit d2dbe24af93b1680c58865825ab2945b1403096c
69 changes: 59 additions & 10 deletions cli/azd/extensions/azure.foundry.ai.agents/internal/cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,7 @@ func (a *InitAction) downloadAgentYaml(
return nil, "", fmt.Errorf("marshaling agent manifest to YAML after parameter processing: %w", err)
}

agentDef := agentManifest.Template.(agent_yaml.AgentDefinition)
agentId := agentDef.Name
agentId := agentManifest.Name

// Use targetDir if provided or set to local file pointer, otherwise default to "src/{agentId}"
if targetDir == "" {
Expand All @@ -720,12 +719,15 @@ func (a *InitAction) downloadAgentYaml(
return nil, "", fmt.Errorf("saving file to %s: %w", filePath, err)
}

if isGitHubUrl && agentDef.Kind == agent_yaml.AgentKindHosted {
// For hosted agents, download the entire parent directory
fmt.Println("Downloading full directory for hosted agent")
err := downloadParentDirectory(ctx, urlInfo, targetDir, ghCli, console)
if err != nil {
return nil, "", fmt.Errorf("downloading parent directory: %w", err)
if isGitHubUrl {
// Check if the template is a HostedContainerAgent
if _, isHostedContainer := agentManifest.Template.(agent_yaml.HostedContainerAgent); isHostedContainer {
// For hosted agents, download the entire parent directory
fmt.Println("Downloading full directory for hosted agent")
err := downloadParentDirectory(ctx, urlInfo, targetDir, ghCli, console)
if err != nil {
return nil, "", fmt.Errorf("downloading parent directory: %w", err)
}
}
}

Expand All @@ -736,7 +738,31 @@ func (a *InitAction) downloadAgentYaml(

func (a *InitAction) addToProject(ctx context.Context, targetDir string, agentManifest *agent_yaml.AgentManifest) error {
var host string
agentDef := agentManifest.Template.(agent_yaml.AgentDefinition)

// Convert the template to bytes
templateBytes, err := json.Marshal(agentManifest.Template)
if err != nil {
return fmt.Errorf("failed to marshal agent template to JSON: %w", err)
}

// Convert the bytes to a dictionary
var templateDict map[string]interface{}
if err := json.Unmarshal(templateBytes, &templateDict); err != nil {
return fmt.Errorf("failed to unmarshal agent template from JSON: %w", err)
}

// Convert the dictionary to bytes
dictJsonBytes, err := json.Marshal(templateDict)
if err != nil {
return fmt.Errorf("failed to marshal templateDict to JSON: %w", err)
}

// Convert the bytes to an Agent Definition
var agentDef agent_yaml.AgentDefinition
if err := json.Unmarshal(dictJsonBytes, &agentDef); err != nil {
return fmt.Errorf("failed to unmarshal JSON to AgentDefinition: %w", err)
}

switch agentDef.Kind {
case "container":
host = "containerapp"
Expand Down Expand Up @@ -1224,7 +1250,30 @@ func downloadDirectoryContents(
// }

func (a *InitAction) updateEnvironment(ctx context.Context, agentManifest *agent_yaml.AgentManifest) error {
agentDef := agentManifest.Template.(agent_yaml.AgentDefinition)
// Convert the template to bytes
templateBytes, err := json.Marshal(agentManifest.Template)
if err != nil {
return fmt.Errorf("failed to marshal agent template to JSON: %w", err)
}

// Convert the bytes to a dictionary
var templateDict map[string]interface{}
if err := json.Unmarshal(templateBytes, &templateDict); err != nil {
return fmt.Errorf("failed to unmarshal agent template from JSON: %w", err)
}

// Convert the dictionary to bytes
dictJsonBytes, err := json.Marshal(templateDict)
if err != nil {
return fmt.Errorf("failed to marshal templateDict to JSON: %w", err)
}

// Convert the bytes to an Agent Definition
var agentDef agent_yaml.AgentDefinition
if err := json.Unmarshal(dictJsonBytes, &agentDef); err != nil {
return fmt.Errorf("failed to unmarshal JSON to AgentDefinition: %w", err)
}

fmt.Printf("Updating environment variables for agent kind: %s\n", agentDef.Kind)

// Get current environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@ import (

// LoadAndValidateAgentManifest parses YAML content and validates it as an AgentManifest
// Returns the parsed manifest and any validation errors
func LoadAndValidateAgentManifest(yamlContent []byte) (*AgentManifest, error) {
func LoadAndValidateAgentManifest(manifestYamlContent []byte) (*AgentManifest, error) {
agentDef, err := ExtractAgentDefinition(manifestYamlContent)
if err != nil {
return nil, fmt.Errorf("YAML content does not conform to AgentManifest format: %w", err)
}

var manifest AgentManifest
if err := yaml.Unmarshal(yamlContent, &manifest); err != nil {
if err := yaml.Unmarshal(manifestYamlContent, &manifest); err != nil {
return nil, fmt.Errorf("YAML content does not conform to AgentManifest format: %w", err)
}
manifest.Template = agentDef

if err := ValidateAgentManifest(&manifest); err != nil {
return nil, err
Expand All @@ -24,79 +30,119 @@ func LoadAndValidateAgentManifest(yamlContent []byte) (*AgentManifest, error) {
return &manifest, nil
}

// Returns a specific agent definition based on the "kind" field in the template
func ExtractAgentDefinition(manifestYamlContent []byte) (any, error) {
var genericManifest map[string]interface{}
if err := yaml.Unmarshal(manifestYamlContent, &genericManifest); err != nil {
return nil, fmt.Errorf("YAML content is not valid: %w", err)
}

template := genericManifest["template"].(map[string]interface{})
templateBytes, _ := yaml.Marshal(template)

var agentDef AgentDefinition
if err := yaml.Unmarshal(templateBytes, &agentDef); err != nil {
return nil, fmt.Errorf("failed to unmarshal to AgentDefinition: %v\n", err)
}

switch agentDef.Kind {
case AgentKindPrompt:
var agent PromptAgent
if err := yaml.Unmarshal(templateBytes, &agent); err != nil {
return nil, fmt.Errorf("failed to unmarshal to PromptAgent: %v\n", err)
}

agent.AgentDefinition = agentDef
return agent, nil
case AgentKindHosted:
var agent HostedContainerAgent
if err := yaml.Unmarshal(templateBytes, &agent); err != nil {
return nil, fmt.Errorf("failed to unmarshal to HostedContainerAgent: %v\n", err)
}

agent.AgentDefinition = agentDef
return agent, nil
case AgentKindContainerApp, AgentKindYamlContainerApp:
var agent ContainerAgent
if err := yaml.Unmarshal(templateBytes, &agent); err != nil {
return nil, fmt.Errorf("failed to unmarshal to ContainerAgent: %v\n", err)
}

agent.AgentDefinition = agentDef
return agent, nil
}

return nil, fmt.Errorf("unrecognized agent kind: %s", agentDef.Kind)
}

// ValidateAgentManifest performs basic validation of an AgentManifest
// Returns an error if the manifest is invalid, nil if valid
func ValidateAgentManifest(manifest *AgentManifest) error {
var errors []string

// First, extract the kind from the template to determine the agent type
templateMap, ok := manifest.Template.(map[string]interface{})
if !ok {
errors = append(errors, "template must be a valid object")
templateBytes, _ := yaml.Marshal(manifest.Template)

var agentDef AgentDefinition
if err := yaml.Unmarshal(templateBytes, &agentDef); err != nil {
errors = append(errors, "failed to parse template to determine agent kind")
} else {
kindValue, hasKind := templateMap["kind"]
if !hasKind {
errors = append(errors, "template.kind is required")
// Validate the kind is supported
if !IsValidAgentKind(agentDef.Kind) {
validKinds := ValidAgentKinds()
validKindStrings := make([]string, len(validKinds))
for i, validKind := range validKinds {
validKindStrings[i] = string(validKind)
}
errors = append(errors, fmt.Sprintf("template.kind must be one of: %v, got '%s'", validKindStrings, agentDef.Kind))
} else {
kind, kindOk := kindValue.(string)
if !kindOk {
errors = append(errors, "template.kind must be a string")
} else {
// Validate the kind is supported
if !IsValidAgentKind(AgentKind(kind)) {
validKinds := ValidAgentKinds()
validKindStrings := make([]string, len(validKinds))
for i, validKind := range validKinds {
validKindStrings[i] = string(validKind)
switch AgentKind(agentDef.Kind) {
case AgentKindPrompt:
var agent PromptAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
if agent.Model.Id == "" {
errors = append(errors, "template.model.id is required")
}
errors = append(errors, fmt.Sprintf("template.kind must be one of: %v, got '%s'", validKindStrings, kind))
} else {
// Convert template to YAML bytes and unmarshal to specific type based on kind
templateBytes, err := yaml.Marshal(manifest.Template)
if err != nil {
errors = append(errors, "failed to process template structure")
} else {
switch AgentKind(kind) {
case AgentKindPrompt:
var agent PromptAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
if agent.Model.Id == "" {
errors = append(errors, "template.model.id is required")
}
}
case AgentKindHosted:
var agent HostedContainerAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
if len(agent.Models) == 0 {
errors = append(errors, "template.models is required and must not be empty")
}
}
case AgentKindContainerApp, AgentKindYamlContainerApp:
var agent ContainerAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
if len(agent.Models) == 0 {
errors = append(errors, "template.models is required and must not be empty")
}
}
case AgentKindWorkflow:
var agent WorkflowAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
// WorkflowAgent doesn't have models, so no model validation needed
}
}
errors = append(errors, fmt.Sprintf("Failed to unmarshal to PromptAgent: %v\n", err))
}
case AgentKindHosted:
var agent HostedContainerAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
// TODO: Do we need this?
// if len(agent.Models) == 0 {
// errors = append(errors, "template.models is required and must not be empty")
// }
} else {
errors = append(errors, fmt.Sprintf("Failed to unmarshal to HostedContainerAgent: %v\n", err))
}
case AgentKindContainerApp, AgentKindYamlContainerApp:
var agent ContainerAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
if len(agent.Models) == 0 {
errors = append(errors, "template.models is required and must not be empty")
}
} else {
errors = append(errors, fmt.Sprintf("Failed to unmarshal to ContainerAgent: %v\n", err))
}
case AgentKindWorkflow:
var agent WorkflowAgent
if err := yaml.Unmarshal(templateBytes, &agent); err == nil {
if agent.Name == "" {
errors = append(errors, "template.name is required")
}
// WorkflowAgent doesn't have models, so no model validation needed
} else {
errors = append(errors, fmt.Sprintf("Failed to unmarshal to WorkflowAgent: %v\n", err))
}
}
}
Expand Down
Loading
Loading