diff --git a/js/dev/server.ts b/js/dev/server.ts index aee357ad5..4e4daaf4d 100644 --- a/js/dev/server.ts +++ b/js/dev/server.ts @@ -117,7 +117,7 @@ export function runDevServer( evalDefs[name] = { parameters, - scores: evaluator.scores.map((score, idx) => ({ + scores: (evaluator.scores ?? []).map((score, idx) => ({ name: scorerName(score, idx), })), }; @@ -209,7 +209,7 @@ export function runDevServer( { ...evaluator, data: evalData.data, - scores: evaluator.scores.concat( + scores: (evaluator.scores ?? []).concat( scores?.map((score) => makeScorer( state, diff --git a/js/src/cli/functions/infer-source.ts b/js/src/cli/functions/infer-source.ts index 179b5c50d..0a759b422 100644 --- a/js/src/cli/functions/infer-source.ts +++ b/js/src/cli/functions/infer-source.ts @@ -85,7 +85,7 @@ export async function findCodeDefinition({ fn = location.position.type === "task" ? evaluator.task - : evaluator.scores[location.position.index]; + : (evaluator.scores ?? [])[location.position.index]; } } else if (location.type === "function") { fn = outFileModule.functions[location.index].handler; diff --git a/js/src/cli/functions/upload.ts b/js/src/cli/functions/upload.ts index ce67f44d9..1745b5329 100644 --- a/js/src/cli/functions/upload.ts +++ b/js/src/cli/functions/upload.ts @@ -180,23 +180,25 @@ export async function uploadHandleBundles({ function_type: "task", origin, }, - ...evaluator.evaluator.scores.map((score, i): BundledFunctionSpec => { - const name = scorerName(score, i); - return { - ...baseInfo, - // There is a very small chance that someone names a function with the same convention, but - // let's assume it's low enough that it doesn't matter. - ...formatNameAndSlug(["eval", namePrefix, "scorer", name]), - description: `Score ${name} for eval ${namePrefix}`, - location: { - type: "experiment", - eval_name: evaluator.evaluator.evalName, - position: { type: "scorer", index: i }, - }, - function_type: "scorer", - origin, - }; - }), + ...(evaluator.evaluator.scores ?? []).map( + (score, i): BundledFunctionSpec => { + const name = scorerName(score, i); + return { + ...baseInfo, + // There is a very small chance that someone names a function with the same convention, but + // let's assume it's low enough that it doesn't matter. + ...formatNameAndSlug(["eval", namePrefix, "scorer", name]), + description: `Score ${name} for eval ${namePrefix}`, + location: { + type: "experiment", + eval_name: evaluator.evaluator.evalName, + position: { type: "scorer", index: i }, + }, + function_type: "scorer", + origin, + }; + }, + ), ]; bundleSpecs.push(...fileSpecs); @@ -219,7 +221,7 @@ export async function uploadHandleBundles({ serializeRemoteEvalParametersContainer(resolvedParameters), } : {}), - scores: evaluator.evaluator.scores.map((score, i) => ({ + scores: (evaluator.evaluator.scores ?? []).map((score, i) => ({ name: scorerName(score, i), })), }; diff --git a/js/src/exports.ts b/js/src/exports.ts index 6eb2b5c1f..352d5a80e 100644 --- a/js/src/exports.ts +++ b/js/src/exports.ts @@ -185,6 +185,7 @@ export type { EvalResult, EvalScorerArgs, EvalScorer, + EvalClassifier, EvaluatorDef, EvaluatorFile, ReporterBody, diff --git a/js/src/framework.test.ts b/js/src/framework.test.ts index bd9381342..45989fcdf 100644 --- a/js/src/framework.test.ts +++ b/js/src/framework.test.ts @@ -179,7 +179,6 @@ describe("runEvaluator", () => { new NoopProgressReporter(), [], undefined, - true, ); expect(out.results.every((r) => Object.keys(r.scores).length === 0)).toBe( @@ -207,7 +206,6 @@ describe("runEvaluator", () => { new NoopProgressReporter(), [], undefined, - true, ); expect( @@ -237,7 +235,6 @@ describe("runEvaluator", () => { new NoopProgressReporter(), [], undefined, - true, ); expect( @@ -271,7 +268,6 @@ describe("runEvaluator", () => { new NoopProgressReporter(), [], undefined, - true, ); expect( @@ -297,7 +293,6 @@ describe("runEvaluator", () => { new NoopProgressReporter(), [], undefined, - true, ); expect( @@ -477,7 +472,7 @@ test("trialIndex is passed to task", async () => { // All results should be correct results.forEach((result) => { expect(result.input).toBe(1); - expect(result.expected).toBe(2); + expect("expected" in result ? result.expected : undefined).toBe(2); expect(result.output).toBe(2); expect(result.error).toBeUndefined(); }); @@ -575,9 +570,8 @@ test("Eval with noSendLogs: true runs locally without creating experiment", asyn test("Eval with returnResults: false produces empty results but valid summary", async () => { const result = await Eval( - "test-no-results", + "test-no-results-project", { - projectName: "test-no-results-project", data: [ { input: "hello", expected: "hello world" }, { input: "test", expected: "test world" }, @@ -615,9 +609,8 @@ test("Eval with returnResults: false produces empty results but valid summary", test("Eval with returnResults: true collects all results", async () => { const result = await Eval( - "test-with-results", + "test-with-results-project", { - projectName: "test-with-results-project", data: [ { input: "hello", expected: "hello world" }, { input: "test", expected: "test world" }, @@ -668,7 +661,7 @@ test("tags can be appended and logged to root span", async () => { evalName: "js-tags-append", data: [{ input: "hello", expected: "hello world", tags: initialTags }], task: (input, hooks) => { - for (const t of appendedTags) hooks.tags.push(t); + for (const t of appendedTags) hooks.tags!.push(t); return input; }, scores: [() => ({ name: "simple_scorer", score: 0.8 })], @@ -825,7 +818,7 @@ test("scorer spans have purpose='scorer' attribute", async () => { data: [{ input: "hello", expected: "hello" }], task: async (input: string) => input, scores: [ - (args: { input: string; output: string; expected: string }) => ({ + (args: { output: string; expected?: string }) => ({ name: "simple_scorer", score: args.output === args.expected ? 1 : 0, }), @@ -972,11 +965,12 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any { name: "test-prompt", slug: "test-prompt", metadata, - }, + } as any, ); const mockProjectMap = { @@ -1001,10 +995,8 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], - { - name: "test-prompt", - slug: "test-prompt", - }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { name: "test-prompt", slug: "test-prompt" } as any, ); const mockProjectMap = { @@ -1027,11 +1019,12 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any { name: "test-prompt", slug: "test-prompt", environments: ["production"], - }, + } as any, ); const mockProjectMap = { @@ -1054,11 +1047,12 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], + // eslint-disable-next-line @typescript-eslint/no-explicit-any { name: "test-prompt", slug: "test-prompt", environments: ["staging", "production"], - }, + } as any, ); const mockProjectMap = { @@ -1084,10 +1078,8 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], - { - name: "test-prompt", - slug: "test-prompt", - }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { name: "test-prompt", slug: "test-prompt" } as any, ); const mockProjectMap = { @@ -1130,11 +1122,8 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], - { - name: "test-prompt", - slug: "test-prompt", - tags, - }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { name: "test-prompt", slug: "test-prompt", tags } as any, ); const mockProjectMap = { @@ -1159,10 +1148,8 @@ describe("framework2 metadata support", () => { options: { model: "gpt-4" }, }, [], - { - name: "test-prompt", - slug: "test-prompt", - }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { name: "test-prompt", slug: "test-prompt" } as any, ); const mockProjectMap = { @@ -1504,3 +1491,126 @@ test("Eval with enableCache: true (default) uses span cache", async () => { expect(startSpy).toHaveBeenCalled(); expect(stopSpy).toHaveBeenCalled(); }); + +test("classifier-only evaluator populates classifications field", async () => { + const result = await Eval( + "test-classifier-only", + { + data: [{ input: "hello", expected: "greeting" }], + task: (input) => input, + classifiers: [ + () => ({ + name: "category", + id: "greeting", + label: "Greeting", + metadata: { source: "unit-test" }, + }), + ], + }, + { noSendLogs: true, returnResults: true }, + ); + + expect(result.results).toHaveLength(1); + const r = result.results[0]; + expect(r.classifications?.category).toEqual([ + { + id: "greeting", + label: "Greeting", + metadata: { source: "unit-test" }, + }, + ]); +}); + +test("scorer-only evaluator populates scores field", async () => { + const result = await Eval( + "test-scorer-only", + { + data: [{ input: "hello", expected: "hello" }], + task: (input) => input, + scores: [ + (args) => ({ + name: "exact_match", + score: args.output === args.expected ? 1 : 0, + }), + ], + }, + { noSendLogs: true, returnResults: true }, + ); + + expect(result.results).toHaveLength(1); + expect(result.results[0].scores.exact_match).toBe(1); + expect(result.results[0].classifications).toBeUndefined(); +}); + +test("multiple classifiers returning the same name append items correctly", async () => { + const result = await Eval( + "test-classifier-append", + { + data: [{ input: "hello" }], + task: (input) => input, + classifiers: [ + () => [ + { name: "category", id: "greeting", label: "Greeting" }, + { name: "category", id: "informal", label: "Informal" }, + ], + ], + }, + { noSendLogs: true, returnResults: true }, + ); + + expect(result.results).toHaveLength(1); + expect(result.results[0].classifications?.category).toHaveLength(2); + expect(result.results[0].classifications?.category[0]).toEqual({ + id: "greeting", + label: "Greeting", + }); + expect(result.results[0].classifications?.category[1]).toEqual({ + id: "informal", + label: "Informal", + }); +}); + +test("mixed evaluator populates both scores and classifications", async () => { + const result = await Eval( + "test-score-and-classify", + { + data: [{ input: "hello", expected: "hello" }], + task: (input) => input, + scores: [ + (args) => ({ + name: "exact_match", + score: args.output === args.expected ? 1 : 0, + }), + ], + classifiers: [ + () => ({ name: "category", id: "greeting", label: "Greeting" }), + ], + }, + { noSendLogs: true, returnResults: true }, + ); + + expect(result.results).toHaveLength(1); + expect(result.results[0].scores.exact_match).toBe(1); + expect(result.results[0].classifications?.category).toEqual([ + { id: "greeting", label: "Greeting" }, + ]); +}); + +test("malformed classifier output fails clearly", async () => { + const result = await Eval( + "test-invalid-classifier-output", + { + data: [{ input: "hello" }], + task: (input) => input, + classifiers: [() => ({}) as never], + }, + { noSendLogs: true, returnResults: true }, + ); + + expect(result.results).toHaveLength(1); + expect((result.results[0] as any).metadata?.classifier_errors).toMatchObject({ + classifier_0: expect.stringMatching( + /must return classifications with a non-empty string name/, + ), + }); +}); diff --git a/js/src/framework.ts b/js/src/framework.ts index 637e688d9..c2bad700d 100644 --- a/js/src/framework.ts +++ b/js/src/framework.ts @@ -1,6 +1,8 @@ import { makeScorerPropagatedEvent, mergeDicts, + Classification, + ClassificationItem, Score, SpanComponentsV3, SpanTypeAttribute, @@ -186,6 +188,17 @@ export type EvalScorer< args: EvalScorerArgs, ) => OneOrMoreScores | Promise; +export type OneOrMoreClassifications = Classification | Classification[] | null; + +export type EvalClassifier< + Input, + Output, + Expected, + Metadata extends BaseMetadata = DefaultMetadataType, +> = ( + args: EvalScorerArgs, +) => OneOrMoreClassifications | Promise; + export type EvalResult< Input, Output, @@ -193,9 +206,10 @@ export type EvalResult< Metadata extends BaseMetadata = DefaultMetadataType, > = EvalCase & { output: Output; - scores: Record; error: unknown; origin?: ObjectReference; + scores: Record; + classifications?: Record; }; type ErrorScoreHandler = (args: { @@ -205,6 +219,10 @@ type ErrorScoreHandler = (args: { unhandledScores: string[]; }) => Record | undefined | void; +/** + * Defines an evaluator. At least one of `scores` or `classifiers` must be provided; + * a runtime error is raised if neither is present. + */ export interface Evaluator< Input, Output, @@ -223,9 +241,17 @@ export interface Evaluator< task: EvalTask; /** - * A set of functions that take an input, output, and expected value and return a score. + * A set of functions that take an input, output, and expected value and return a {@link Score}. + * At least one of `scores` or `classifiers` must be provided. */ - scores: EvalScorer[]; + scores?: EvalScorer[]; + + /** + * A set of functions that take an input, output, and expected value and return a + * {@link Classification}. Results are recorded under the `classifications` column. + * At least one of `scores` or `classifiers` must be provided. + */ + classifiers?: EvalClassifier[]; /** * A set of parameters that will be passed to the evaluator. @@ -864,6 +890,132 @@ export function scorerName( return scorer.name || `scorer_${scorer_idx}`; } +function classifierName( + classifier: EvalClassifier, + classifier_idx: number, +) { + return classifier.name || `classifier_${classifier_idx}`; +} + +function buildSpanMetadata( + results: Array<{ name: string; metadata?: Record }>, +) { + return results.length === 1 + ? results[0].metadata + : results.reduce( + (prev, s) => mergeDicts(prev, { [s.name]: s.metadata }), + {}, + ); +} + +function buildSpanScores( + results: Array<{ + name: string; + score: number | null; + metadata?: Record; + }>, +) { + const scoresRecord = results.reduce( + (prev, s) => mergeDicts(prev, { [s.name]: s.score }), + {}, + ); + return { resultMetadata: buildSpanMetadata(results), scoresRecord }; +} + +async function runInScorerSpan( + rootSpan: Span, + spanName: string, + spanType: SpanTypeAttribute, + propagatedEvent: ReturnType, + eventInput: unknown, + fn: (span: Span) => Promise, +): Promise< + { kind: "score"; value: T[] | null } | { kind: "error"; value: unknown } +> { + try { + const value = await rootSpan.traced(fn, { + name: spanName, + spanAttributes: { type: spanType, purpose: "scorer" }, + propagatedEvent, + event: { input: eventInput }, + }); + return { kind: "score", value }; + } catch (e) { + return { kind: "error", value: e }; + } +} + +function collectScoringResults( + runResults: Array< + { kind: "score"; value: T[] | null } | { kind: "error"; value: unknown } + >, + names: string[], + onResult: (result: T) => void, +): { name: string; error: unknown }[] { + const failing: { name: string; error: unknown }[] = []; + runResults.forEach((r, i) => { + if (r.kind === "score") { + (r.value ?? []).forEach(onResult); + } else { + failing.push({ name: names[i], error: r.value }); + } + }); + return failing; +} + +function validateClassificationResult( + value: unknown, + scorerName: string, +): Classification { + if (!(typeof value === "object" && value !== null && !isEmpty(value))) { + throw new Error( + `When returning structured classifier results, each classification must be a non-empty object. Got: ${JSON.stringify(value)}`, + ); + } + if (!("name" in value) || typeof value.name !== "string" || !value.name) { + throw new Error( + `Classifier ${scorerName} must return classifications with a non-empty string name. Got: ${JSON.stringify(value)}`, + ); + } + if (!("id" in value) || typeof value.id !== "string" || !value.id) { + throw new Error( + `Classifier ${scorerName} must return classifications with a non-empty string id. Got: ${JSON.stringify(value)}`, + ); + } + return value as Classification; +} + +function toClassificationItem(c: Classification): ClassificationItem { + return { + id: c.id, + label: c.label ?? c.id, + ...(c.metadata !== undefined ? { metadata: c.metadata } : {}), + }; +} + +function logScoringFailures( + kind: string, + failures: { name: string; error: unknown }[], + metadata: Record, + rootSpan: Span, + state: BraintrustState | undefined, +): string[] { + if (!failures.length) return []; + const errorMap = Object.fromEntries( + failures.map(({ name, error }) => [ + name, + error instanceof Error ? error.stack : `${error}`, + ]), + ); + metadata[`${kind}_errors`] = errorMap; + rootSpan.log({ metadata: { [`${kind}_errors`]: errorMap } }); + debugLogger.forState(state).warn( + `Found exceptions for the following ${kind}s: ${Object.keys(errorMap).join(", ")}`, + failures.map((f) => f.error), + ); + return Object.keys(errorMap); +} + export async function runEvaluator( experiment: Experiment | null, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -876,6 +1028,11 @@ export async function runEvaluator( enableCache = true, // eslint-disable-next-line @typescript-eslint/no-explicit-any ): Promise> { + if (!evaluator.scores && !evaluator.classifiers) { + throw new Error( + "Evaluator must include at least one of `scores` or `classifiers`", + ); + } return await runEvaluatorInternal( experiment, evaluator, @@ -1089,7 +1246,11 @@ async function runEvaluatorInternal( let error: unknown | undefined = undefined; let tags: string[] = [...(datum.tags ?? [])]; const scores: Record = {}; - const scorerNames = evaluator.scores.map(scorerName); + const classifications: Record = {}; + const scorerNames = (evaluator.scores ?? []).map(scorerName); + const classifierNames = (evaluator.classifiers ?? []).map( + classifierName, + ); let unhandledScores: string[] | null = scorerNames; try { const meta = (o: Record) => @@ -1154,139 +1315,156 @@ async function runEvaluatorInternal( output, trace, }; - const scoreResults = await Promise.all( - evaluator.scores.map(async (score, score_idx) => { - try { - const runScorer = async (span: Span) => { - const scoreResult = score(scoringArgs); - const scoreValue = - scoreResult instanceof Promise - ? await scoreResult - : scoreResult; - - if (scoreValue === null) { - return null; - } - - if (Array.isArray(scoreValue)) { - for (const s of scoreValue) { - if (!(typeof s === "object" && !isEmpty(s))) { - throw new Error( - `When returning an array of scores, each score must be a non-empty object. Got: ${JSON.stringify( - s, - )}`, - ); + const { trace: _trace, ...scoringArgsForLogging } = scoringArgs; + const propagatedEvent = makeScorerPropagatedEvent( + await rootSpan.export(), + ); + + const getOtherFields = (s: Score) => { + const { metadata: _metadata, name: _name, ...rest } = s; + return rest; + }; + + const [scoreResults, classificationResults] = await Promise.all([ + Promise.all( + (evaluator.scores ?? []).map((score, score_idx) => + runInScorerSpan( + rootSpan, + scorerNames[score_idx], + SpanTypeAttribute.SCORE, + propagatedEvent, + scoringArgsForLogging, + async (span) => { + const scoreValue = await Promise.resolve( + score(scoringArgs), + ); + if (scoreValue === null) return null; + if (Array.isArray(scoreValue)) { + for (const s of scoreValue) { + if (!(typeof s === "object" && !isEmpty(s))) { + throw new Error( + `When returning an array of scores, each score must be a non-empty object. Got: ${JSON.stringify(s)}`, + ); + } } } - } - - const results = Array.isArray(scoreValue) - ? scoreValue - : typeof scoreValue === "object" && !isEmpty(scoreValue) - ? [scoreValue] - : [ - { - name: scorerNames[score_idx], - score: scoreValue, - }, - ]; - - const getOtherFields = (s: Score) => { - const { metadata: _metadata, name: _name, ...rest } = s; - return rest; - }; - - const resultMetadata = - results.length === 1 - ? results[0].metadata - : results.reduce( - (prev, s) => - mergeDicts(prev, { - [s.name]: s.metadata, - }), - {}, - ); - - const resultOutput = - results.length === 1 - ? getOtherFields(results[0]) - : results.reduce( - (prev, s) => - mergeDicts(prev, { [s.name]: getOtherFields(s) }), - {}, - ); - - const scores = results.reduce( - (prev, s) => mergeDicts(prev, { [s.name]: s.score }), - {}, - ); - - span.log({ - output: resultOutput, - metadata: resultMetadata, - scores: scores, - }); - return results; - }; - - // Exclude trace from logged input since it contains internal state - // that shouldn't be serialized (spansFlushPromise, spansFlushed, etc.) - const { trace: _trace, ...scoringArgsForLogging } = - scoringArgs; - const results = await rootSpan.traced(runScorer, { - name: scorerNames[score_idx], - spanAttributes: { - type: SpanTypeAttribute.SCORE, - purpose: "scorer", + const results: Score[] = Array.isArray(scoreValue) + ? scoreValue + : typeof scoreValue === "object" && !isEmpty(scoreValue) + ? [scoreValue] + : [ + { + name: scorerNames[score_idx], + score: scoreValue, + }, + ]; + const { resultMetadata, scoresRecord } = + buildSpanScores(results); + const resultOutput = + results.length === 1 + ? getOtherFields(results[0]) + : results.reduce( + (prev, s) => + mergeDicts(prev, { + [s.name]: getOtherFields(s), + }), + {}, + ); + span.log({ + output: resultOutput, + metadata: resultMetadata, + scores: scoresRecord, + }); + return results; + }, + ), + ), + ), + Promise.all( + (evaluator.classifiers ?? []).map((classifier, idx) => + runInScorerSpan( + rootSpan, + classifierNames[idx], + SpanTypeAttribute.CLASSIFIER, + propagatedEvent, + scoringArgsForLogging, + async (span) => { + const classifierValue = await Promise.resolve( + classifier(scoringArgs), + ); + if (classifierValue === null) return null; + const rawResults = ( + Array.isArray(classifierValue) + ? classifierValue + : [classifierValue] + ).map((result) => + validateClassificationResult( + result, + classifierNames[idx], + ), + ); + const resultOutput = + rawResults.length === 1 + ? toClassificationItem(rawResults[0]) + : rawResults.reduce( + (prev, r) => + mergeDicts(prev, { + [r.name]: toClassificationItem(r), + }), + {}, + ); + span.log({ + output: resultOutput, + metadata: buildSpanMetadata(rawResults), + }); + return rawResults; }, - propagatedEvent: makeScorerPropagatedEvent( - await rootSpan.export(), - ), - event: { input: scoringArgsForLogging }, - }); - return { kind: "score", value: results } as const; - } catch (e) { - return { kind: "error", value: e } as const; + ), + ), + ), + ]); + + const failingScorers = collectScoringResults( + scoreResults, + scorerNames, + (result) => { + scores[result.name] = result.score; + }, + ); + + const failingClassifiers = collectScoringResults( + classificationResults, + classifierNames, + (result) => { + const item = toClassificationItem(result); + if (!classifications[result.name]) { + classifications[result.name] = []; } - }), + classifications[result.name].push(item); + }, ); - // Resolve each promise on its own so that we can separate the passing - // from the failing ones. - const failingScorersAndResults: { name: string; error: unknown }[] = - []; - scoreResults.forEach((results, i) => { - const name = scorerNames[i]; - if (results.kind === "score") { - (results.value || []).forEach((result) => { - scores[result.name] = result.score; - }); - } else { - failingScorersAndResults.push({ name, error: results.value }); - } - }); - unhandledScores = null; - if (failingScorersAndResults.length) { - const scorerErrors = Object.fromEntries( - failingScorersAndResults.map(({ name, error }) => [ - name, - error instanceof Error ? error.stack : `${error}`, - ]), - ); - metadata["scorer_errors"] = scorerErrors; - rootSpan.log({ - metadata: { scorer_errors: scorerErrors }, - }); - const names = Object.keys(scorerErrors).join(", "); - const errors = failingScorersAndResults.map((item) => item.error); - unhandledScores = Object.keys(scorerErrors); - debugLogger - .forState(evaluator.state) - .warn( - `Found exceptions for the following scorers: ${names}`, - errors, - ); + if (Object.keys(classifications).length > 0) { + rootSpan.log({ classifications }); } + + const failedScorerNames = logScoringFailures( + "scorer", + failingScorers, + metadata, + rootSpan, + evaluator.state, + ); + unhandledScores = failedScorerNames.length + ? failedScorerNames + : null; + logScoringFailures( + "classifier", + failingClassifiers, + metadata, + rootSpan, + evaluator.state, + ); } catch (e) { logSpanError(rootSpan, e); error = e; @@ -1310,15 +1488,21 @@ async function runEvaluatorInternal( } if (collectResults) { - collectedResults.push({ + const baseResult = { input: datum.input, ...("expected" in datum ? { expected: datum.expected } : {}), output, tags: tags.length ? tags : undefined, metadata, - scores: mergedScores, error, origin: baseEvent.event?.origin, + }; + collectedResults.push({ + ...baseResult, + scores: mergedScores, + ...(Object.keys(classifications).length > 0 + ? { classifications } + : {}), }); } }; diff --git a/js/src/parameters.test.ts b/js/src/parameters.test.ts index dbba8ea49..d5b7b7e4b 100644 --- a/js/src/parameters.test.ts +++ b/js/src/parameters.test.ts @@ -26,6 +26,7 @@ test("parameters are passed to task", async () => { return output; }, scores: [], + classifiers: [], parameters: { prefix: z.string().default("start:"), suffix: z.string().default(":end"), @@ -59,6 +60,7 @@ test("prompt parameter is passed correctly", async () => { return input; }, scores: [], + classifiers: [], parameters: { main: { type: "prompt", @@ -99,6 +101,7 @@ test("custom parameter values override defaults", async () => { return output; }, scores: [], + classifiers: [], parameters: { prefix: z.string().default("start:"), suffix: z.string().default(":end"), @@ -131,6 +134,7 @@ test("array parameter is handled correctly", async () => { return input; }, scores: [], + classifiers: [], parameters: { items: z.array(z.string()).default(["item1", "item2"]), }, @@ -161,6 +165,7 @@ test("object parameter is handled correctly", async () => { return input; }, scores: [], + classifiers: [], parameters: { config: z .object({ @@ -196,6 +201,7 @@ test("model parameter defaults to configured value", async () => { return input; }, scores: [], + classifiers: [], parameters: { model: { type: "model", @@ -224,6 +230,7 @@ test("model parameter is required when default is missing", async () => { data: [{ input: "test" }], task: async (input: string) => input, scores: [], + classifiers: [], parameters: { model: { type: "model", diff --git a/js/util/index.ts b/js/util/index.ts index 25a76cc03..52b082cc1 100644 --- a/js/util/index.ts +++ b/js/util/index.ts @@ -55,7 +55,13 @@ export { ensureNewDatasetRecord, } from "./object"; -export type { Score, Scorer, ScorerArgs } from "./score"; +export type { + Classification, + ClassificationItem, + Score, + Scorer, + ScorerArgs, +} from "./score"; export { constructJsonArray, deterministicReplacer } from "./json_util"; diff --git a/js/util/object.ts b/js/util/object.ts index 735f52960..fea8735d6 100644 --- a/js/util/object.ts +++ b/js/util/object.ts @@ -21,6 +21,7 @@ export type OtherExperimentLogFields = { error: unknown; tags: string[]; scores: Record; + classifications?: Record; metadata: Record; metrics: Record; datasetRecordId: string; diff --git a/js/util/score.ts b/js/util/score.ts index 758902344..08daebeef 100644 --- a/js/util/score.ts +++ b/js/util/score.ts @@ -1,3 +1,23 @@ +/** + * The result returned by a classifier function. Unlike `Score`, `id` is + * required and the span will be recorded as a classifier span. + */ +export interface Classification { + name: string; + id: string; + label?: string; + metadata?: Record; +} + +/** + * The serialized form of a classification stored in the `classifications` log record. + */ +export interface ClassificationItem { + id: string; + label: string; + metadata?: Record; +} + export interface Score { name: string; score: number | null;