diff --git a/README.md b/README.md index 9c44258..ea38177 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ curl -X POST localhost:4000/api/agents/dev/my-skill/run \ ```bash skrun init my-agent +skrun init my-agent --provider google cd my-agent # Creates SKILL.md (instructions) + agent.yaml (config) ``` diff --git a/docs/cli.md b/docs/cli.md index e70fde2..264d0d0 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -7,6 +7,7 @@ Create a new Skrun agent. ```bash skrun init [dir] skrun init my-agent +skrun init my-agent --provider google skrun init --from-skill ./existing-skill ``` @@ -17,6 +18,7 @@ skrun init --from-skill ./existing-skill | `--force` | Overwrite existing files | | `--name ` | Agent name (non-interactive) | | `--description ` | Agent description (non-interactive) | +| `--provider ` | Provider with a default model: `anthropic`, `openai`, `google`, `mistral`, `groq` | | `--model ` | Model (non-interactive) | | `--namespace ` | Namespace (non-interactive) | diff --git a/packages/cli/src/commands/init.test.ts b/packages/cli/src/commands/init.test.ts new file mode 100644 index 0000000..9c29c5b --- /dev/null +++ b/packages/cli/src/commands/init.test.ts @@ -0,0 +1,39 @@ +import { describe, expect, it, vi } from "vitest"; +import * as prompts from "../utils/prompts.js"; +import { resolveInitModel } from "./init.js"; + +describe("resolveInitModel", () => { + it("uses the provider default model without prompting", async () => { + const askModelSpy = vi.spyOn(prompts, "askModel"); + + await expect(resolveInitModel({ provider: "google" })).resolves.toEqual({ + provider: "google", + name: "gemini-2.5-flash", + }); + expect(askModelSpy).not.toHaveBeenCalled(); + }); + + it("prefers an explicit model over the provider flag", async () => { + const askModelSpy = vi.spyOn(prompts, "askModel"); + + await expect( + resolveInitModel({ provider: "google", model: "openai/gpt-4o-mini" }), + ).resolves.toEqual({ + provider: "openai", + name: "gpt-4o-mini", + }); + expect(askModelSpy).not.toHaveBeenCalled(); + }); + + it("falls back to the interactive model prompt", async () => { + const askModelSpy = vi + .spyOn(prompts, "askModel") + .mockResolvedValue({ provider: "anthropic", name: "claude-sonnet-4-20250514" }); + + await expect(resolveInitModel({})).resolves.toEqual({ + provider: "anthropic", + name: "claude-sonnet-4-20250514", + }); + expect(askModelSpy).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/cli/src/commands/init.ts b/packages/cli/src/commands/init.ts index 61dba7b..73e331c 100644 --- a/packages/cli/src/commands/init.ts +++ b/packages/cli/src/commands/init.ts @@ -1,9 +1,9 @@ import { existsSync, mkdirSync, writeFileSync } from "node:fs"; import { basename, join, resolve } from "node:path"; import { AgentConfigSchema, serializeAgentYaml } from "@skrun-dev/schema"; -import type { Command } from "commander"; +import { type Command, Option } from "commander"; import * as format from "../utils/format.js"; -import { askModel, askText } from "../utils/prompts.js"; +import { DEFAULT_MODELS_BY_PROVIDER, askModel, askText } from "../utils/prompts.js"; import { initFromSkill } from "./init-from-skill.js"; const SKILL_MD_TEMPLATE = (name: string, description: string) => `--- @@ -31,6 +31,12 @@ export function registerInitCommand(program: Command): void { .option("--force", "Overwrite existing files") .option("--name ", "Agent name (non-interactive)") .option("--description ", "Agent description (non-interactive)") + .addOption( + new Option( + "--provider ", + "Provider with a default model (non-interactive)", + ).choices(Object.keys(DEFAULT_MODELS_BY_PROVIDER)), + ) .option("--model ", "Model as provider/name (non-interactive)") .option("--namespace ", "Agent namespace (non-interactive)") .action(async (dir: string | undefined, opts) => { @@ -46,10 +52,32 @@ interface InitOptions { force?: boolean; name?: string; description?: string; + provider?: keyof typeof DEFAULT_MODELS_BY_PROVIDER; model?: string; namespace?: string; } +export async function resolveInitModel( + opts: Pick, +): Promise<{ provider: string; name: string }> { + if (opts.model) { + const parts = opts.model.split("/"); + return { + provider: parts[0], + name: parts.slice(1).join("/"), + }; + } + + if (opts.provider) { + return { + provider: opts.provider, + name: DEFAULT_MODELS_BY_PROVIDER[opts.provider], + }; + } + + return askModel(); +} + async function runInit(dir: string | undefined, opts: InitOptions): Promise { const targetDir = dir ? resolve(dir) : process.cwd(); const dirName = basename(targetDir); @@ -66,17 +94,7 @@ async function runInit(dir: string | undefined, opts: InitOptions): Promise