diff --git a/.changeset/chilled-jokes-teach.md b/.changeset/chilled-jokes-teach.md new file mode 100644 index 00000000..207ed3c7 --- /dev/null +++ b/.changeset/chilled-jokes-teach.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": minor +--- + +Observe got a major upgrade. Now it will return a suggested playwright method with any necessary arguments for the generated candidate elements. It also includes a major speedup when using a11y tree processing for context. diff --git a/evals/evals.config.json b/evals/evals.config.json index 6906cb06..9e4b11d1 100644 --- a/evals/evals.config.json +++ b/evals/evals.config.json @@ -215,6 +215,30 @@ { "name": "extract_zillow", "categories": ["text_extract"] + }, + { + "name": "observe_github", + "categories": ["observe"] + }, + { + "name": "observe_vantechjournal", + "categories": ["observe"] + }, + { + "name": "observe_amazon_add_to_cart", + "categories": ["observe"] + }, + { + "name": "observe_simple_google_search", + "categories": ["observe"] + }, + { + "name": "observe_yc_startup", + "categories": ["observe"] + }, + { + "name": "observe_taxes", + "categories": ["observe"] } ] } diff --git a/evals/tasks/ionwave_observe.ts b/evals/tasks/ionwave_observe.ts index 15cc2d8d..b0433342 100644 --- a/evals/tasks/ionwave_observe.ts +++ b/evals/tasks/ionwave_observe.ts @@ -1,11 +1,7 @@ import { initStagehand } from "@/evals/initStagehand"; import { EvalFunction } from "@/types/evals"; -export const ionwave_observe: EvalFunction = async ({ - modelName, - logger, - useAccessibilityTree, -}) => { +export const ionwave_observe: EvalFunction = async ({ modelName, logger }) => { const { stagehand, initResponse } = await initStagehand({ modelName, logger, @@ -15,7 +11,7 @@ export const ionwave_observe: EvalFunction = async ({ await stagehand.page.goto("https://elpasotexas.ionwave.net/Login.aspx"); - const observations = await stagehand.page.observe({ useAccessibilityTree }); + const observations = await stagehand.page.observe({ onlyVisible: true }); if (observations.length === 0) { await stagehand.close(); diff --git a/evals/tasks/observe_amazon_add_to_cart.ts b/evals/tasks/observe_amazon_add_to_cart.ts new file mode 100644 index 00000000..8be49da1 --- /dev/null +++ b/evals/tasks/observe_amazon_add_to_cart.ts @@ -0,0 +1,75 @@ +import { EvalFunction } from "@/types/evals"; +import { initStagehand } from "@/evals/initStagehand"; +import { performPlaywrightMethod } from "@/lib/a11y/utils"; + +export const observe_amazon_add_to_cart: EvalFunction = async ({ + modelName, + logger, +}) => { + const { stagehand, initResponse } = await initStagehand({ + modelName, + logger, + }); + + const { debugUrl, sessionUrl } = initResponse; + + await stagehand.page.goto( + "https://www.amazon.com/Laptop-MacBook-Surface-Water-Resistant-Accessories/dp/B0D5M4H5CD", + ); + + await stagehand.page.waitForTimeout(5000); + + const observations1 = await stagehand.page.observe({ + instruction: "Find and click the 'Add to Cart' button", + onlyVisible: false, + returnAction: true, + }); + + console.log(observations1); + + // Example of using performPlaywrightMethod if you have the xpath + if (observations1.length > 0) { + const action1 = observations1[0]; + await performPlaywrightMethod( + stagehand.page, + stagehand.logger, + action1.method, + action1.arguments, + action1.selector.replace("xpath=", ""), + ); + } + + await stagehand.page.waitForTimeout(2000); + + const observations2 = await stagehand.page.observe({ + instruction: "Find and click the 'Proceed to checkout' button", + onlyVisible: false, + returnAction: true, + }); + + // Example of using performPlaywrightMethod if you have the xpath + if (observations2.length > 0) { + const action2 = observations2[0]; + await performPlaywrightMethod( + stagehand.page, + stagehand.logger, + action2.method, + action2.arguments, + action2.selector.replace("xpath=", ""), + ); + } + await stagehand.page.waitForTimeout(2000); + + const currentUrl = stagehand.page.url(); + const expectedUrlPrefix = "https://www.amazon.com/ap/signin"; + + await stagehand.close(); + + return { + _success: currentUrl.startsWith(expectedUrlPrefix), + currentUrl, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; +}; diff --git a/evals/tasks/observe_github.ts b/evals/tasks/observe_github.ts new file mode 100644 index 00000000..1effed68 --- /dev/null +++ b/evals/tasks/observe_github.ts @@ -0,0 +1,92 @@ +import { initStagehand } from "@/evals/initStagehand"; +import { EvalFunction } from "@/types/evals"; + +export const observe_github: EvalFunction = async ({ modelName, logger }) => { + const { stagehand, initResponse } = await initStagehand({ + modelName, + logger, + }); + + const { debugUrl, sessionUrl } = initResponse; + + await stagehand.page.goto( + "https://github.com/browserbase/stagehand/tree/main/lib", + ); + + const observations = await stagehand.page.observe({ + instruction: "find the scrollable element that holds the repos file tree", + }); + + if (observations.length === 0) { + await stagehand.close(); + return { + _success: false, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; + } + + const possibleLocators = [ + `#repos-file-tree > div.Box-sc-g0xbh4-0.jbQqON > div > div > div > nav > ul`, + `#repos-file-tree > div.Box-sc-g0xbh4-0.jbQqON > div > div > div > nav`, + ]; + + const possibleHandles = []; + for (const locatorStr of possibleLocators) { + const locator = stagehand.page.locator(locatorStr); + const handle = await locator.elementHandle(); + if (handle) { + possibleHandles.push({ locatorStr, handle }); + } + } + + let foundMatch = false; + let matchedLocator: string | null = null; + + for (const observation of observations) { + try { + const observationLocator = stagehand.page + .locator(observation.selector) + .first(); + const observationHandle = await observationLocator.elementHandle(); + if (!observationHandle) { + continue; + } + + for (const { locatorStr, handle: candidateHandle } of possibleHandles) { + const isSameNode = await observationHandle.evaluate( + (node, otherNode) => node === otherNode, + candidateHandle, + ); + if (isSameNode) { + foundMatch = true; + matchedLocator = locatorStr; + break; + } + } + + if (foundMatch) { + break; + } + } catch (error) { + console.warn( + `Failed to check observation with selector ${observation.selector}:`, + error.message, + ); + continue; + } + } + + await stagehand.close(); + + return { + _success: foundMatch, + matchedLocator, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; +}; diff --git a/evals/tasks/observe_simple_google_search.ts b/evals/tasks/observe_simple_google_search.ts new file mode 100644 index 00000000..16eb38ff --- /dev/null +++ b/evals/tasks/observe_simple_google_search.ts @@ -0,0 +1,70 @@ +import { EvalFunction } from "@/types/evals"; +import { initStagehand } from "@/evals/initStagehand"; +import { performPlaywrightMethod } from "@/lib/a11y/utils"; + +export const observe_simple_google_search: EvalFunction = async ({ + modelName, + logger, +}) => { + const { stagehand, initResponse } = await initStagehand({ + modelName, + logger, + }); + + const { debugUrl, sessionUrl } = initResponse; + + await stagehand.page.goto("https://www.google.com"); + + // await stagehand.page.act({ + // action: 'Search for "OpenAI"', + // }); + const observation1 = await stagehand.page.observe({ + instruction: "Find the search bar and enter 'OpenAI'", + onlyVisible: false, + returnAction: true, + }); + console.log(observation1); + + if (observation1.length > 0) { + const action1 = observation1[0]; + await performPlaywrightMethod( + stagehand.page, + stagehand.logger, + action1.method, + action1.arguments, + action1.selector.replace("xpath=", ""), + ); + } + await stagehand.page.waitForTimeout(5000); + const observation2 = await stagehand.page.observe({ + instruction: "Click the search button in the suggestions dropdown", + onlyVisible: false, + returnAction: true, + }); + console.log(observation2); + + if (observation2.length > 0) { + const action2 = observation2[0]; + await performPlaywrightMethod( + stagehand.page, + stagehand.logger, + action2.method, + action2.arguments, + action2.selector.replace("xpath=", ""), + ); + } + await stagehand.page.waitForTimeout(5000); + + const expectedUrl = "https://www.google.com/search?q=OpenAI"; + const currentUrl = stagehand.page.url(); + + await stagehand.close(); + + return { + _success: currentUrl.startsWith(expectedUrl), + currentUrl, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; +}; diff --git a/evals/tasks/observe_taxes.ts b/evals/tasks/observe_taxes.ts new file mode 100644 index 00000000..33a7a85e --- /dev/null +++ b/evals/tasks/observe_taxes.ts @@ -0,0 +1,76 @@ +import { EvalFunction } from "@/types/evals"; +import { initStagehand } from "@/evals/initStagehand"; + +export const observe_taxes: EvalFunction = async ({ modelName, logger }) => { + const { stagehand, initResponse } = await initStagehand({ + modelName, + logger, + }); + + const { debugUrl, sessionUrl } = initResponse; + + await stagehand.page.goto("https://file.1040.com/estimate/"); + + const observations = await stagehand.page.observe({ + instruction: "Find all the form elements under the 'Income' section", + }); + + if (observations.length === 0) { + await stagehand.close(); + return { + _success: false, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; + } else if (observations.length < 13) { + await stagehand.close(); + return { + _success: false, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; + } + + const expectedLocator = `#tpWages`; + + const expectedResult = await stagehand.page + .locator(expectedLocator) + .first() + .innerText(); + + let foundMatch = false; + for (const observation of observations) { + try { + const observationResult = await stagehand.page + .locator(observation.selector) + .first() + .innerText(); + + if (observationResult === expectedResult) { + foundMatch = true; + break; + } + } catch (error) { + console.warn( + `Failed to check observation with selector ${observation.selector}:`, + error.message, + ); + continue; + } + } + + await stagehand.close(); + + return { + _success: foundMatch, + expected: expectedResult, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; +}; diff --git a/evals/tasks/observe_vantechjournal.ts b/evals/tasks/observe_vantechjournal.ts new file mode 100644 index 00000000..4ca7dbbf --- /dev/null +++ b/evals/tasks/observe_vantechjournal.ts @@ -0,0 +1,80 @@ +import { initStagehand } from "@/evals/initStagehand"; +import { EvalFunction } from "@/types/evals"; + +export const observe_vantechjournal: EvalFunction = async ({ + modelName, + logger, +}) => { + const { stagehand, initResponse } = await initStagehand({ + modelName, + logger, + }); + + const { debugUrl, sessionUrl } = initResponse; + + await stagehand.page.goto("https://vantechjournal.com/archive?page=8"); + await stagehand.page.waitForTimeout(1000); + + const observations = await stagehand.page.observe({ + instruction: "find the button that takes us to the 11th page", + }); + + if (observations.length === 0) { + await stagehand.close(); + return { + _success: false, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; + } + + const expectedLocator = `a.rounded-lg:nth-child(8)`; + + const expectedResult = await stagehand.page.locator(expectedLocator); + + let foundMatch = false; + + for (const observation of observations) { + try { + const observationLocator = stagehand.page + .locator(observation.selector) + .first(); + const observationHandle = await observationLocator.elementHandle(); + const expectedHandle = await expectedResult.elementHandle(); + + if (!observationHandle || !expectedHandle) { + // Couldn’t get handles, skip + continue; + } + + const isSameNode = await observationHandle.evaluate( + (node, otherNode) => node === otherNode, + expectedHandle, + ); + + if (isSameNode) { + foundMatch = true; + break; + } + } catch (error) { + console.warn( + `Failed to check observation with selector ${observation.selector}:`, + error.message, + ); + continue; + } + } + + await stagehand.close(); + + return { + _success: foundMatch, + expected: expectedResult, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; +}; diff --git a/evals/tasks/observe_yc_startup.ts b/evals/tasks/observe_yc_startup.ts new file mode 100644 index 00000000..913f398f --- /dev/null +++ b/evals/tasks/observe_yc_startup.ts @@ -0,0 +1,95 @@ +import { initStagehand } from "@/evals/initStagehand"; +import { EvalFunction } from "@/types/evals"; + +export const observe_yc_startup: EvalFunction = async ({ + modelName, + logger, +}) => { + const { stagehand, initResponse } = await initStagehand({ + modelName, + logger, + }); + + const { debugUrl, sessionUrl } = initResponse; + + await stagehand.page.goto("https://www.ycombinator.com/companies"); + await stagehand.page.waitForLoadState("networkidle"); + + const observations = await stagehand.page.observe({ + instruction: + "Find the container element that holds links to each of the startup companies. The companies each have a name, a description, and a link to their website.", + }); + + if (observations.length === 0) { + await stagehand.close(); + return { + _success: false, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; + } + + const possibleLocators = [ + `div._section_1pgsr_163._results_1pgsr_343`, + `div._rightCol_1pgsr_592`, + ]; + + const possibleHandles = []; + for (const locatorStr of possibleLocators) { + const locator = stagehand.page.locator(locatorStr); + const handle = await locator.elementHandle(); + if (handle) { + possibleHandles.push({ locatorStr, handle }); + } + } + + let foundMatch = false; + let matchedLocator: string | null = null; + + for (const observation of observations) { + try { + const observationLocator = stagehand.page + .locator(observation.selector) + .first(); + const observationHandle = await observationLocator.elementHandle(); + if (!observationHandle) { + continue; + } + + for (const { locatorStr, handle: candidateHandle } of possibleHandles) { + const isSameNode = await observationHandle.evaluate( + (node, otherNode) => node === otherNode, + candidateHandle, + ); + if (isSameNode) { + foundMatch = true; + matchedLocator = locatorStr; + break; + } + } + + if (foundMatch) { + break; + } + } catch (error) { + console.warn( + `Failed to check observation with selector ${observation.selector}:`, + error.message, + ); + continue; + } + } + + await stagehand.close(); + + return { + _success: foundMatch, + matchedLocator, + observations, + debugUrl, + sessionUrl, + logs: logger.getLogs(), + }; +}; diff --git a/evals/tasks/panamcs.ts b/evals/tasks/panamcs.ts index 330af12b..afc98c8f 100644 --- a/evals/tasks/panamcs.ts +++ b/evals/tasks/panamcs.ts @@ -1,11 +1,7 @@ import { EvalFunction } from "@/types/evals"; import { initStagehand } from "@/evals/initStagehand"; -export const panamcs: EvalFunction = async ({ - modelName, - logger, - useAccessibilityTree, -}) => { +export const panamcs: EvalFunction = async ({ modelName, logger }) => { const { stagehand, initResponse } = await initStagehand({ modelName, logger, @@ -15,7 +11,7 @@ export const panamcs: EvalFunction = async ({ await stagehand.page.goto("https://panamcs.org/about/staff/"); - const observations = await stagehand.page.observe({ useAccessibilityTree }); + const observations = await stagehand.page.observe({ onlyVisible: true }); if (observations.length === 0) { await stagehand.close(); diff --git a/evals/tasks/shopify_homepage.ts b/evals/tasks/shopify_homepage.ts index 271f4c56..e846422e 100644 --- a/evals/tasks/shopify_homepage.ts +++ b/evals/tasks/shopify_homepage.ts @@ -1,11 +1,7 @@ import { EvalFunction } from "@/types/evals"; import { initStagehand } from "@/evals/initStagehand"; -export const shopify_homepage: EvalFunction = async ({ - modelName, - logger, - useAccessibilityTree, -}) => { +export const shopify_homepage: EvalFunction = async ({ modelName, logger }) => { const { stagehand, initResponse } = await initStagehand({ modelName, logger, @@ -15,7 +11,7 @@ export const shopify_homepage: EvalFunction = async ({ await stagehand.page.goto("https://www.shopify.com/"); - const observations = await stagehand.page.observe({ useAccessibilityTree }); + const observations = await stagehand.page.observe({ onlyVisible: true }); if (observations.length === 0) { await stagehand.close(); diff --git a/evals/tasks/vanta.ts b/evals/tasks/vanta.ts index 73a7906c..2959389c 100644 --- a/evals/tasks/vanta.ts +++ b/evals/tasks/vanta.ts @@ -1,11 +1,7 @@ import { EvalFunction } from "@/types/evals"; import { initStagehand } from "@/evals/initStagehand"; -export const vanta: EvalFunction = async ({ - modelName, - logger, - useAccessibilityTree, -}) => { +export const vanta: EvalFunction = async ({ modelName, logger }) => { const { stagehand, initResponse } = await initStagehand({ modelName, logger, @@ -16,7 +12,7 @@ export const vanta: EvalFunction = async ({ await stagehand.page.goto("https://www.vanta.com/"); await stagehand.page.act({ action: "close the cookies popup" }); - const observations = await stagehand.page.observe({ useAccessibilityTree }); + const observations = await stagehand.page.observe({ onlyVisible: true }); if (observations.length === 0) { await stagehand.close(); diff --git a/evals/tasks/vanta_h.ts b/evals/tasks/vanta_h.ts index 606659c4..eca69a0f 100644 --- a/evals/tasks/vanta_h.ts +++ b/evals/tasks/vanta_h.ts @@ -1,11 +1,7 @@ import { EvalFunction } from "@/types/evals"; import { initStagehand } from "@/evals/initStagehand"; -export const vanta_h: EvalFunction = async ({ - modelName, - logger, - useAccessibilityTree, -}) => { +export const vanta_h: EvalFunction = async ({ modelName, logger }) => { const { stagehand, initResponse } = await initStagehand({ modelName, logger, @@ -17,7 +13,7 @@ export const vanta_h: EvalFunction = async ({ const observations = await stagehand.page.observe({ instruction: "find the buy now button if it is available", - useAccessibilityTree, + onlyVisible: true, }); await stagehand.close(); diff --git a/lib/StagehandPage.ts b/lib/StagehandPage.ts index 07317a1b..e5239a4e 100644 --- a/lib/StagehandPage.ts +++ b/lib/StagehandPage.ts @@ -471,12 +471,13 @@ export class StagehandPage { : instructionOrOptions || {}; const { - instruction = "Find actions that can be performed on this page.", + instruction, modelName, modelClientOptions, useVision, // still destructure but will not pass it on domSettleTimeoutMs, - useAccessibilityTree = false, + returnAction = false, + onlyVisible = false, } = options; if (typeof useVision !== "undefined") { @@ -510,8 +511,8 @@ export class StagehandPage { value: llmClient.modelName, type: "string", }, - useAccessibilityTree: { - value: useAccessibilityTree ? "true" : "false", + onlyVisible: { + value: onlyVisible ? "true" : "false", type: "boolean", }, }, @@ -523,7 +524,8 @@ export class StagehandPage { llmClient, requestId, domSettleTimeoutMs, - useAccessibilityTree, + returnAction, + onlyVisible, }) .catch((e) => { this.stagehand.log({ diff --git a/lib/a11y/utils.ts b/lib/a11y/utils.ts index 4d710c6d..fcbb9e17 100644 --- a/lib/a11y/utils.ts +++ b/lib/a11y/utils.ts @@ -1,7 +1,11 @@ import { AccessibilityNode, TreeResult, AXNode } from "../../types/context"; import { StagehandPage } from "../StagehandPage"; import { LogLine } from "../../types/log"; -import { CDPSession } from "playwright"; +import { CDPSession, Page, Locator } from "playwright"; +import { + PlaywrightCommandMethodNotSupportedException, + PlaywrightCommandException, +} from "@/types/playwright"; // Parser function for str output export function formatSimplifiedTree( @@ -173,32 +177,33 @@ export async function getAccessibilityTree( // This function is wrapped into a string and sent as a CDP command // It is not meant to be actually executed here -function getNodePath(node: Element) { - const parts = []; - let current = node; - - while (current && current.parentNode) { - if (current.nodeType === Node.ELEMENT_NODE) { - let tagName = current.tagName.toLowerCase(); - const sameTagSiblings = Array.from(current.parentNode.children).filter( - (child) => child.tagName === current.tagName, - ); - - if (sameTagSiblings.length > 1) { - let index = 1; - for (const sibling of sameTagSiblings) { - if (sibling === current) break; - index++; - } - tagName += "[" + index + "]"; +function getNodePath(el: Element) { + if (!el || el.nodeType !== Node.ELEMENT_NODE) return ""; + const pathSegments = []; + let current = el; + while (current && current.nodeType === Node.ELEMENT_NODE) { + const tagName = current.nodeName.toLowerCase(); + let index = 1; + let sibling = current.previousSibling; + while (sibling) { + if ( + sibling.nodeType === Node.ELEMENT_NODE && + sibling.nodeName.toLowerCase() === tagName + ) { + index++; } - - parts.unshift(tagName); + sibling = sibling.previousSibling; } + const segment = index > 1 ? tagName + "[" + index + "]" : tagName; + pathSegments.unshift(segment); current = current.parentNode as Element; + if (!current || !current.parentNode) break; + if (current.nodeName.toLowerCase() === "html") { + pathSegments.unshift("html"); + break; + } } - - return "/" + parts.join("/"); + return "/" + pathSegments.join("/"); } const functionString = getNodePath.toString(); @@ -218,3 +223,319 @@ export async function getXPathByResolvedObjectId( return result.value || ""; } + +export async function performPlaywrightMethod( + stagehandPage: Page, + logger: (logLine: LogLine) => void, + method: string, + args: unknown[], + xpath: string, + // domSettleTimeoutMs?: number, +) { + const locator = stagehandPage.locator(`xpath=${xpath}`).first(); + const initialUrl = stagehandPage.url(); + + logger({ + category: "action", + message: "performing playwright method", + level: 2, + auxiliary: { + xpath: { + value: xpath, + type: "string", + }, + method: { + value: method, + type: "string", + }, + }, + }); + + if (method === "scrollIntoView") { + logger({ + category: "action", + message: "scrolling element into view", + level: 2, + auxiliary: { + xpath: { + value: xpath, + type: "string", + }, + }, + }); + try { + await locator + .evaluate((element: HTMLElement) => { + element.scrollIntoView({ behavior: "smooth", block: "center" }); + }) + .catch((e: Error) => { + logger({ + category: "action", + message: "error scrolling element into view", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + trace: { + value: e.stack, + type: "string", + }, + xpath: { + value: xpath, + type: "string", + }, + }, + }); + }); + } catch (e) { + logger({ + category: "action", + message: "error scrolling element into view", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + trace: { + value: e.stack, + type: "string", + }, + xpath: { + value: xpath, + type: "string", + }, + }, + }); + + throw new PlaywrightCommandException(e.message); + } + } else if (method === "fill" || method === "type") { + try { + await locator.fill(""); + await locator.click(); + const text = args[0]?.toString(); + for (const char of text) { + await stagehandPage.keyboard.type(char, { + delay: Math.random() * 50 + 25, + }); + } + } catch (e) { + logger({ + category: "action", + message: "error filling element", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + trace: { + value: e.stack, + type: "string", + }, + xpath: { + value: xpath, + type: "string", + }, + }, + }); + + throw new PlaywrightCommandException(e.message); + } + } else if (method === "press") { + try { + const key = args[0]?.toString(); + await stagehandPage.keyboard.press(key); + } catch (e) { + logger({ + category: "action", + message: "error pressing key", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + trace: { + value: e.stack, + type: "string", + }, + key: { + value: args[0]?.toString() ?? "unknown", + type: "string", + }, + }, + }); + + throw new PlaywrightCommandException(e.message); + } + } else if (typeof locator[method as keyof typeof locator] === "function") { + // Log current URL before action + logger({ + category: "action", + message: "page URL before action", + level: 2, + auxiliary: { + url: { + value: stagehandPage.url(), + type: "string", + }, + }, + }); + + // Perform the action + try { + await ( + locator[method as keyof Locator] as unknown as ( + ...args: string[] + ) => Promise + )(...args.map((arg) => arg?.toString() || "")); + } catch (e) { + logger({ + category: "action", + message: "error performing method", + level: 1, + auxiliary: { + error: { + value: e.message, + type: "string", + }, + trace: { + value: e.stack, + type: "string", + }, + xpath: { + value: xpath, + type: "string", + }, + method: { + value: method, + type: "string", + }, + args: { + value: JSON.stringify(args), + type: "object", + }, + }, + }); + + throw new PlaywrightCommandException(e.message); + } + + // Handle navigation if a new page is opened + if (method === "click") { + logger({ + category: "action", + message: "clicking element, checking for page navigation", + level: 1, + auxiliary: { + xpath: { + value: xpath, + type: "string", + }, + }, + }); + + const newOpenedTab = await Promise.race([ + new Promise((resolve) => { + Promise.resolve(stagehandPage.context()).then((context) => { + context.once("page", (page: Page) => resolve(page)); + setTimeout(() => resolve(null), 1_500); + }); + }), + ]); + + logger({ + category: "action", + message: "clicked element", + level: 1, + auxiliary: { + newOpenedTab: { + value: newOpenedTab ? "opened a new tab" : "no new tabs opened", + type: "string", + }, + }, + }); + + if (newOpenedTab) { + logger({ + category: "action", + message: "new page detected (new tab) with URL", + level: 1, + auxiliary: { + url: { + value: newOpenedTab.url(), + type: "string", + }, + }, + }); + await newOpenedTab.close(); + await stagehandPage.goto(newOpenedTab.url()); + await stagehandPage.waitForLoadState("domcontentloaded"); + // await stagehandPage._waitForSettledDom(domSettleTimeoutMs); + } + + await Promise.race([ + stagehandPage.waitForLoadState("networkidle"), + new Promise((resolve) => setTimeout(resolve, 5_000)), + ]).catch((e) => { + logger({ + category: "action", + message: "network idle timeout hit", + level: 1, + auxiliary: { + trace: { + value: e.stack, + type: "string", + }, + message: { + value: e.message, + type: "string", + }, + }, + }); + }); + + logger({ + category: "action", + message: "finished waiting for (possible) page navigation", + level: 1, + }); + + if (stagehandPage.url() !== initialUrl) { + logger({ + category: "action", + message: "new page detected with URL", + level: 1, + auxiliary: { + url: { + value: stagehandPage.url(), + type: "string", + }, + }, + }); + } + } + } else { + logger({ + category: "action", + message: "chosen method is invalid", + level: 1, + auxiliary: { + method: { + value: method, + type: "string", + }, + }, + }); + + throw new PlaywrightCommandMethodNotSupportedException( + `Method ${method} not supported`, + ); + } + + // await stagehandPage._waitForSettledDom(domSettleTimeoutMs); +} diff --git a/lib/handlers/observeHandler.ts b/lib/handlers/observeHandler.ts index 9267edbb..89d09214 100644 --- a/lib/handlers/observeHandler.ts +++ b/lib/handlers/observeHandler.ts @@ -53,13 +53,15 @@ export class StagehandObserveHandler { instruction, llmClient, requestId, - useAccessibilityTree = false, + returnAction, + onlyVisible, }: { instruction: string; llmClient: LLMClient; requestId: string; domSettleTimeoutMs?: number; - useAccessibilityTree?: boolean; + returnAction?: boolean; + onlyVisible?: boolean; }) { if (!instruction) { instruction = `Find elements that can be used for any future actions in the page. These may be navigation links, related pages, section/subsection links, buttons, or other interactive elements. Be comprehensive: if there are multiple elements that may be relevant for future actions, return all of them.`; @@ -76,65 +78,23 @@ export class StagehandObserveHandler { }, }); - let outputString: string; let selectorMap: Record = {}; - const backendNodeIdMap: Record = {}; - - await this.stagehandPage.startDomDebug(); - await this.stagehandPage.enableCDP("DOM"); - - const evalResult = await this.stagehand.page.evaluate(() => { - return window.processAllOfDom().then((result) => result); - }); - - // For each element in the selector map, get its backendNodeId - for (const [index, xpaths] of Object.entries(evalResult.selectorMap)) { - try { - // Use the first xpath to find the element - const xpath = xpaths[0]; - const { result } = await this.stagehandPage.sendCDP<{ - result: { objectId: string }; - }>("Runtime.evaluate", { - expression: `document.evaluate('${xpath}', document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue`, - returnByValue: false, - }); - - if (result.objectId) { - // Get the node details using CDP - const { node } = await this.stagehandPage.sendCDP<{ - node: { backendNodeId: number }; - }>("DOM.describeNode", { - objectId: result.objectId, - depth: -1, - pierce: true, - }); - - if (node.backendNodeId) { - backendNodeIdMap[index] = node.backendNodeId; - } - } - } catch (error) { - console.warn( - `Failed to get backendNodeId for element ${index}:`, - error, - ); - continue; - } - } - - await this.stagehandPage.disableCDP("DOM"); - ({ outputString, selectorMap } = evalResult); - + let outputString: string; + const useAccessibilityTree = !onlyVisible; if (useAccessibilityTree) { + await this.stagehandPage._waitForSettledDom(); const tree = await getAccessibilityTree(this.stagehandPage, this.logger); - this.logger({ category: "observation", message: "Getting accessibility tree data", level: 1, }); - outputString = tree.simplified; + } else { + const evalResult = await this.stagehand.page.evaluate(() => { + return window.processAllOfDom().then((result) => result); + }); + ({ outputString, selectorMap } = evalResult); } // No screenshot or vision-based annotation is performed @@ -146,47 +106,38 @@ export class StagehandObserveHandler { userProvidedInstructions: this.userProvidedInstructions, logger: this.logger, isUsingAccessibilityTree: useAccessibilityTree, + returnAction, }); const elementsWithSelectors = await Promise.all( observationResponse.elements.map(async (element) => { const { elementId, ...rest } = element; if (useAccessibilityTree) { - const index = Object.entries(backendNodeIdMap).find( - ([, value]) => value === elementId, - )?.[0]; - if (!index || !selectorMap[index]?.[0]) { - // Generate xpath for the given element if not found in selectorMap - const { object } = await this.stagehandPage.sendCDP<{ - object: { objectId: string }; - }>("DOM.resolveNode", { - backendNodeId: elementId, - }); - const xpath = await getXPathByResolvedObjectId( - await this.stagehandPage.getCDPClient(), - object.objectId, - ); - return { - ...rest, - selector: xpath, - backendNodeId: elementId, - }; - } + // Generate xpath for the given element if not found in selectorMap + const { object } = await this.stagehandPage.sendCDP<{ + object: { objectId: string }; + }>("DOM.resolveNode", { + backendNodeId: elementId, + }); + const xpath = await getXPathByResolvedObjectId( + await this.stagehandPage.getCDPClient(), + object.objectId, + ); return { ...rest, - selector: `xpath=${selectorMap[index][0]}`, - backendNodeId: elementId, + selector: `xpath=${xpath}`, + // Provisioning or future use if we want to use direct CDP + // backendNodeId: elementId, }; } return { ...rest, selector: `xpath=${selectorMap[elementId][0]}`, - backendNodeId: backendNodeIdMap[elementId], + // backendNodeId: backendNodeIdMap[elementId], }; }), ); - await this.stagehandPage.cleanupDomDebug(); this.logger({ diff --git a/lib/inference.ts b/lib/inference.ts index c67e2ae1..ff0043a6 100644 --- a/lib/inference.ts +++ b/lib/inference.ts @@ -271,6 +271,7 @@ export async function observe({ isUsingAccessibilityTree, userProvidedInstructions, logger, + returnAction = false, }: { instruction: string; domElements: string; @@ -279,6 +280,7 @@ export async function observe({ userProvidedInstructions?: string; logger: (message: LogLine) => void; isUsingAccessibilityTree?: boolean; + returnAction?: boolean; }) { const observeSchema = z.object({ elements: z @@ -292,6 +294,22 @@ export async function observe({ ? "a description of the accessible element and its purpose" : "a description of the element and what it is relevant for", ), + ...(returnAction + ? { + method: z + .string() + .describe( + "the candidate method/action to interact with the element. Select one of the available Playwright interaction methods.", + ), + arguments: z.array( + z + .string() + .describe( + "the arguments to pass to the method. For example, for a click, the arguments are empty, but for a fill, the arguments are the value to fill in.", + ), + ), + } + : {}), }), ) .describe( @@ -331,10 +349,20 @@ export async function observe({ }); const parsedResponse = { elements: - observationResponse.elements?.map((el) => ({ - elementId: Number(el.elementId), - description: String(el.description), - })) ?? [], + observationResponse.elements?.map((el) => { + const base = { + elementId: Number(el.elementId), + description: String(el.description), + }; + + return returnAction + ? { + ...base, + method: String(el.method), + arguments: el.arguments, + } + : base; + }) ?? [], } satisfies { elements: { elementId: number; description: string }[] }; return parsedResponse; diff --git a/types/stagehand.ts b/types/stagehand.ts index 54a03e2f..a088638e 100644 --- a/types/stagehand.ts +++ b/types/stagehand.ts @@ -89,10 +89,15 @@ export interface ObserveOptions { /** @deprecated Vision is not supported in this version of Stagehand. */ useVision?: boolean; domSettleTimeoutMs?: number; - useAccessibilityTree?: boolean; + returnAction?: boolean; + onlyVisible?: boolean; } export interface ObserveResult { selector: string; description: string; + backendNodeId?: number; + //TODO: review name + method?: string; + arguments?: string[]; }