From c98321d2283e94386507e63ade3c9188b3c7f50e Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 19 Dec 2024 13:33:15 -0800 Subject: [PATCH 1/5] init --- .../CompareEvaluationsPage.tsx | 6 +- .../pages/CompareEvaluationsPage/ecpState.ts | 33 ++-- .../pages/CompareEvaluationsPage/ecpTypes.ts | 35 ++-- .../ComparisonDefinitionSection.tsx | 12 +- .../EvaluationDefinition.tsx | 12 +- .../ExampleCompareSection.tsx | 12 +- .../exampleCompareSectionUtil.ts | 14 +- .../ExampleFilterSection.tsx | 28 +-- .../ScorecardSection/ScorecardSection.tsx | 18 +- .../SummaryPlotsSection.tsx | 8 +- .../tsDataModelHooksEvaluationComparison.ts | 175 ++++++++++++------ 11 files changed, 212 insertions(+), 141 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index 478c4887546..a22994f5ffa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -179,8 +179,8 @@ const CompareEvaluationsPageInner: React.FC<{ }> = props => { const {state, setSelectedMetrics} = useCompareEvaluationsState(); const showExampleFilter = - Object.keys(state.data.evaluationCalls).length === 2; - const showExamples = Object.keys(state.data.resultRows).length > 0; + Object.keys(state.summary.evaluationCalls).length === 2; + const showExamples = Object.keys(state.summary.resultRows).length > 0; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index e5c1b03d60a..ffe901edf27 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -7,9 +7,9 @@ import {useMemo} from 'react'; -import {useEvaluationComparisonData} from '../wfReactInterface/tsDataModelHooksEvaluationComparison'; +import {useEvaluationComparisonResults, useEvaluationComparisonSummary} from '../wfReactInterface/tsDataModelHooksEvaluationComparison'; import {Loadable} from '../wfReactInterface/wfDataModelHooksInterface'; -import {EvaluationComparisonData} from './ecpTypes'; +import {EvaluationComparisonResults, EvaluationComparisonSummary} from './ecpTypes'; import {getMetricIds} from './ecpUtil'; /** @@ -17,7 +17,9 @@ import {getMetricIds} from './ecpUtil'; */ export type EvaluationComparisonState = { // The normalized data for the evaluations - data: EvaluationComparisonData; + summary: EvaluationComparisonSummary; + // The results of the evaluations + results: EvaluationComparisonResults | null; // The dimensions to compare & filter results comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view @@ -50,18 +52,19 @@ export const useEvaluationComparisonState = ( const orderedCallIds = useMemo(() => { return getCallIdsOrderedForQuery(evaluationCallIds); }, [evaluationCallIds]); - const data = useEvaluationComparisonData(entity, project, orderedCallIds); + const summaryData = useEvaluationComparisonSummary(entity, project, orderedCallIds); + const resultsData = useEvaluationComparisonResults(entity, project, orderedCallIds, summaryData.result); const value = useMemo(() => { - if (data.result == null || data.loading) { + if (summaryData.result == null || summaryData.loading) { return {loading: true, result: null}; } const scorerDimensions = Object.keys( - getMetricIds(data.result, 'score', 'scorer') + getMetricIds(summaryData.result, 'score', 'scorer') ); const derivedDimensions = Object.keys( - getMetricIds(data.result, 'score', 'derived') + getMetricIds(summaryData.result, 'score', 'derived') ); let newComparisonDimensions = comparisonDimensions; @@ -93,21 +96,15 @@ export const useEvaluationComparisonState = ( return { loading: false, result: { - data: data.result, + summary: summaryData.result, + results: resultsData.result, comparisonDimensions: newComparisonDimensions, selectedInputDigest, selectedMetrics, evaluationCallIdsOrdered: evaluationCallIds, }, }; - }, [ - data.result, - data.loading, - comparisonDimensions, - selectedInputDigest, - selectedMetrics, - evaluationCallIds, - ]); + }, [summaryData.result, summaryData.loading, comparisonDimensions, resultsData.result, selectedInputDigest, selectedMetrics, evaluationCallIds]); return value; }; @@ -132,8 +129,8 @@ const getCallIdsOrderedForQuery = (callIds: string[]) => { */ export const getOrderedModelRefs = (state: EvaluationComparisonState) => { const baselineCallId = getBaselineCallId(state); - const baselineRef = state.data.evaluationCalls[baselineCallId].modelRef; - const refs = Object.keys(state.data.models); + const baselineRef = state.summary.evaluationCalls[baselineCallId].modelRef; + const refs = Object.keys(state.summary.models); // Make sure the baseline model is first moveItemToFront(refs, baselineRef); return refs; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts index b4642fae240..2a505f3a16e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts @@ -7,7 +7,8 @@ */ import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes'; -export type EvaluationComparisonData = { + +export type EvaluationComparisonSummary = { // Entity and Project are constant across all calls entity: string; project: string; @@ -23,6 +24,20 @@ export type EvaluationComparisonData = { [callId: string]: EvaluationCall; }; + // Models are the Weave Objects used to define the model logic and properties. + models: { + [modelRef: string]: ModelObj; + }; + + // ScoreMetrics define the metrics that are associated on each individual prediction + scoreMetrics: MetricDefinitionMap; + + // SummaryMetrics define the metrics that are associated with the evaluation as a whole + // often aggregated from the scoreMetrics. + summaryMetrics: MetricDefinitionMap; +}; + +export type EvaluationComparisonResults = { // Inputs are the intersection of all inputs used in the evaluations. // Note, we are able to "merge" the same input digest even if it is // used in different evaluations. @@ -30,11 +45,6 @@ export type EvaluationComparisonData = { [rowDigest: string]: DatasetRow; }; - // Models are the Weave Objects used to define the model logic and properties. - models: { - [modelRef: string]: ModelObj; - }; - // ResultRows are the actual results of running the evaluation against // the inputs. resultRows: { @@ -53,16 +63,8 @@ export type EvaluationComparisonData = { }; }; }; - }; - - // ScoreMetrics define the metrics that are associated on each individual prediction - scoreMetrics: MetricDefinitionMap; - - // SummaryMetrics define the metrics that are associated with the evaluation as a whole - // often aggregated from the scoreMetrics. - summaryMetrics: MetricDefinitionMap; -}; - + } +} /** * The EvaluationObj is the primary object that defines the evaluation itself. */ @@ -84,6 +86,7 @@ export type EvaluationCall = { name: string; color: string; summaryMetrics: MetricResultMap; + traceId: string; }; /** diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index 2704a66cbea..c32cda7fb50 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -44,9 +44,9 @@ export const ComparisonDefinitionSection: React.FC<{ return callIds.map(callId => ({ key: 'evaluations', value: callId, - label: props.state.data.evaluationCalls[callId]?.name ?? callId, + label: props.state.summary.evaluationCalls[callId]?.name ?? callId, })); - }, [callIds, props.state.data.evaluationCalls]); + }, [callIds, props.state.summary.evaluationCalls]); const onSetBaseline = (value: string | null) => { if (!value) { @@ -130,8 +130,8 @@ const AddEvaluationButton: React.FC<{ // Calls query for just evaluations const evaluationsFilter = useEvaluationsFilter( - props.state.data.entity, - props.state.data.project + props.state.summary.entity, + props.state.summary.project ); const page = useMemo( () => ({ @@ -144,8 +144,8 @@ const AddEvaluationButton: React.FC<{ // Don't query for output here, re-queried in tsDataModelHooksEvaluationComparison.ts const columns = useMemo(() => ['inputs', 'display_name'], []); const calls = useCallsForQuery( - props.state.data.entity, - props.state.data.project, + props.state.summary.entity, + props.state.summary.project, evaluationsFilter, DEFAULT_FILTER_CALLS, page, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx index 5dcf835e378..1894398e553 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx @@ -21,11 +21,11 @@ export const EvaluationCallLink: React.FC<{ callId: string; state: EvaluationComparisonState; }> = props => { - const evaluationCall = props.state.data.evaluationCalls?.[props.callId]; + const evaluationCall = props.state.summary.evaluationCalls?.[props.callId]; if (!evaluationCall) { return null; } - const {entity, project} = props.state.data; + const {entity, project} = props.state.summary; return ( = props => { const {useObjectVersion} = useWFHooks(); - const evaluationCall = props.state.data.evaluationCalls[props.callId]; - const modelObj = props.state.data.models[evaluationCall.modelRef]; + const evaluationCall = props.state.summary.evaluationCalls[props.callId]; + const modelObj = props.state.summary.models[evaluationCall.modelRef]; const objRef = useMemo( () => parseRef(modelObj.ref) as WeaveObjectRef, [modelObj.ref] @@ -95,9 +95,9 @@ export const EvaluationDatasetLink: React.FC<{ callId: string; state: EvaluationComparisonState; }> = props => { - const evaluationCall = props.state.data.evaluationCalls[props.callId]; + const evaluationCall = props.state.summary.evaluationCalls[props.callId]; const evaluationObj = - props.state.data.evaluations[evaluationCall.evaluationRef]; + props.state.summary.evaluations[evaluationCall.evaluationRef]; const parsed = parseRef(evaluationObj.datasetRef); if (!parsed) { return null; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx index 398f65ecd45..4af9900f2df 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx @@ -229,15 +229,15 @@ export const ExampleCompareSection: React.FC<{ }>({}); const onScorerClick = usePeekCall( - props.state.data.entity, - props.state.data.project + props.state.summary.entity, + props.state.summary.project ); const {ref1, ref2} = useLinkHorizontalScroll(); const compositeScoreMetrics = useMemo( - () => buildCompositeMetricsMap(props.state.data, 'score'), - [props.state.data] + () => buildCompositeMetricsMap(props.state.summary, 'score'), + [props.state.summary] ); if (target == null) { @@ -261,7 +261,7 @@ export const ExampleCompareSection: React.FC<{ const numEvals = numTrials.length; // Get derived scores, then filter out any not in the selected metrics const derivedScores = Object.values( - getMetricIds(props.state.data, 'score', 'derived') + getMetricIds(props.state.summary, 'score', 'derived') ).filter( score => props.state.selectedMetrics?.[flattenedDimensionPath(score)] ); @@ -483,7 +483,7 @@ export const ExampleCompareSection: React.FC<{ trialPredict?.op_name ?? '' )?.artifactName; const trialCallId = trialPredict?.id; - const evaluationCall = props.state.data.evaluationCalls[currEvalCallId]; + const evaluationCall = props.state.summary.evaluationCalls[currEvalCallId]; if (trialEntity && trialProject && trialOpName && trialCallId) { return ( { const leafDims = useMemo(() => getOrderedCallIds(state), [state]); const compositeMetricsMap = useMemo( - () => buildCompositeMetricsMap(state.data, 'score'), - [state.data] + () => buildCompositeMetricsMap(state.summary, 'score'), + [state.summary] ); const flattenedRows = useMemo(() => { const rows: FlattenedRow[] = []; - Object.entries(state.data.resultRows).forEach( + Object.entries(state.summary.resultRows).forEach( ([rowDigest, rowCollection]) => { Object.values(rowCollection.evaluations).forEach(modelCollection => { Object.values(modelCollection.predictAndScores).forEach( predictAndScoreRes => { const datasetRow = - state.data.inputs[predictAndScoreRes.rowDigest]; + state.summary.inputs[predictAndScoreRes.rowDigest]; if (datasetRow != null) { const output = predictAndScoreRes._rawPredictTraceData?.output; rows.push({ @@ -143,7 +143,7 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { }), output: flattenObjectPreservingWeaveTypes({output}), scores: Object.fromEntries( - [...Object.entries(state.data.scoreMetrics)].map( + [...Object.entries(state.summary.scoreMetrics)].map( ([scoreKey, scoreVal]) => { return [ scoreKey, @@ -169,7 +169,7 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { } ); return rows; - }, [state.data.resultRows, state.data.inputs, state.data.scoreMetrics]); + }, [state.summary.resultRows, state.summary.inputs, state.summary.scoreMetrics]); const pivotedRows = useMemo(() => { // Ok, so in this step we are going to pivot - diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index 1146c8ea960..47552b94c53 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -99,12 +99,12 @@ const SingleDimensionFilter: React.FC<{ dimensionIndex: number; }> = props => { const compositeMetricsMap = useMemo(() => { - return buildCompositeMetricsMap(props.state.data, 'score'); - }, [props.state.data]); + return buildCompositeMetricsMap(props.state.summary, 'score'); + }, [props.state.summary]); const {setComparisonDimensions} = useCompareEvaluationsState(); const baselineCallId = getBaselineCallId(props.state); - const compareCallId = Object.keys(props.state.data.evaluationCalls).find( + const compareCallId = Object.keys(props.state.summary.evaluationCalls).find( callId => callId !== baselineCallId )!; @@ -112,14 +112,14 @@ const SingleDimensionFilter: React.FC<{ props.state.comparisonDimensions?.[props.dimensionIndex]; const targetDimension = targetComparisonDimension - ? props.state.data.scoreMetrics[targetComparisonDimension.metricId] + ? props.state.summary.scoreMetrics[targetComparisonDimension.metricId] : undefined; const xIsPercentage = targetDimension?.scoreType === 'binary'; const yIsPercentage = targetDimension?.scoreType === 'binary'; - const xColor = props.state.data.evaluationCalls[baselineCallId].color; - const yColor = props.state.data.evaluationCalls[compareCallId].color; + const xColor = props.state.summary.evaluationCalls[baselineCallId].color; + const yColor = props.state.summary.evaluationCalls[compareCallId].color; const {filteredRows} = useFilteredAggregateRows(props.state); const filteredDigest = useMemo(() => { @@ -141,7 +141,7 @@ const SingleDimensionFilter: React.FC<{ ); if (baselineTargetDimension != null && compareTargetDimension != null) { - Object.entries(props.state.data.resultRows).forEach(([digest, row]) => { + Object.entries(props.state.summary.resultRows).forEach(([digest, row]) => { const xVals: number[] = []; const yVals: number[] = []; Object.values( @@ -230,7 +230,7 @@ const SingleDimensionFilter: React.FC<{ compareCallId, compositeMetricsMap, filteredDigest, - props.state.data.resultRows, + props.state.summary.resultRows, targetDimension, ]); @@ -281,15 +281,15 @@ const SingleDimensionFilter: React.FC<{ yIsPercentage={yIsPercentage} xTitle={ 'Baseline: ' + - props.state.data.evaluationCalls[baselineCallId].name + + props.state.summary.evaluationCalls[baselineCallId].name + ' ' + - props.state.data.evaluationCalls[baselineCallId].callId.slice(-4) + props.state.summary.evaluationCalls[baselineCallId].callId.slice(-4) } yTitle={ 'Challenger: ' + - props.state.data.evaluationCalls[compareCallId].name + + props.state.summary.evaluationCalls[compareCallId].name + ' ' + - props.state.data.evaluationCalls[compareCallId].callId.slice(-4) + props.state.summary.evaluationCalls[compareCallId].callId.slice(-4) } /> @@ -303,11 +303,11 @@ const DimensionPicker: React.FC<{ props.state.comparisonDimensions?.[props.dimensionIndex]; const currDimension = targetComparisonDimension - ? props.state.data.scoreMetrics[targetComparisonDimension.metricId] + ? props.state.summary.scoreMetrics[targetComparisonDimension.metricId] : undefined; const {setComparisonDimensions} = useCompareEvaluationsState(); - const dimensionMap = props.state.data.scoreMetrics; + const dimensionMap = props.state.summary.scoreMetrics; return ( Object.values(props.state.data.evaluations).map(e => e.datasetRef), + () => Object.values(props.state.summary.evaluations).map(e => e.datasetRef), [props.state] ); const evalCallIds = useMemo( @@ -89,7 +89,7 @@ export const ScorecardSection: React.FC<{ const modelProps = useMemo(() => { const propsRes: {[prop: string]: {[ref: string]: any}} = {}; modelRefs.forEach(ref => { - const model = props.state.data.models[ref]; + const model = props.state.summary.models[ref]; Object.keys(model.properties).forEach(prop => { if (!propsRes[prop]) { propsRes[prop] = {}; @@ -100,7 +100,7 @@ export const ScorecardSection: React.FC<{ // Make sure predict op is last modelRefs.forEach(ref => { - const model = props.state.data.models[ref]; + const model = props.state.summary.models[ref]; if (!propsRes.predict) { propsRes.predict = {}; } @@ -108,7 +108,7 @@ export const ScorecardSection: React.FC<{ }); return propsRes; - }, [modelRefs, props.state.data.models]); + }, [modelRefs, props.state.summary.models]); const propsWithDifferences = useMemo(() => { return Object.keys(modelProps).filter(prop => { const values = Object.values(modelProps[prop]); @@ -119,15 +119,15 @@ export const ScorecardSection: React.FC<{ const compositeSummaryMetrics = useMemo(() => { return buildCompositeMetricsMap( - props.state.data, + props.state.summary, 'summary', props.state.selectedMetrics ); }, [props.state]); const onCallClick = usePeekCall( - props.state.data.entity, - props.state.data.project + props.state.summary.entity, + props.state.summary.project ); const datasetVariation = Array.from(new Set(datasetRefs)).length > 1; @@ -295,7 +295,7 @@ export const ScorecardSection: React.FC<{ {evalCallIds.map((evalCallId, mNdx) => { const model = - props.state.data.evaluationCalls[evalCallId].modelRef; + props.state.summary.evaluationCalls[evalCallId].modelRef; const parsed = parseRefMaybe( modelProps[prop][model] ) as WeaveObjectRef; @@ -560,7 +560,7 @@ const resolveSummaryMetricResult = ( const baseline = baselineDimension ? resolveSummaryMetricResultForEvaluateCall( baselineDimension, - state.data.evaluationCalls[evalCallId] + state.summary.evaluationCalls[evalCallId] ) : undefined; return baseline; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 02c456df850..2b1250080f3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -430,7 +430,7 @@ const usePlotDataFromMetrics = ( state: EvaluationComparisonState ): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { - return buildCompositeMetricsMap(state.data, 'summary'); + return buildCompositeMetricsMap(state.summary, 'summary'); }, [state]); const callIds = useMemo(() => { return getOrderedCallIds(state); @@ -450,7 +450,7 @@ const usePlotDataFromMetrics = ( } const val = resolveSummaryMetricValueForEvaluateCall( metricDimension, - state.data.evaluationCalls[callId] + state.summary.evaluationCalls[callId] ); if (typeof val === 'boolean') { return val ? 1 : 0; @@ -471,7 +471,7 @@ const usePlotDataFromMetrics = ( }); const radarData = Object.fromEntries( callIds.map(callId => { - const evalCall = state.data.evaluationCalls[callId]; + const evalCall = state.summary.evaluationCalls[callId]; return [ evalCall.callId, { @@ -491,5 +491,5 @@ const usePlotDataFromMetrics = ( ); const allMetricNames = new Set(metrics.map(m => m.metricLabel)); return {radarData, allMetricNames}; - }, [callIds, compositeMetrics, state.data.evaluationCalls]); + }, [callIds, compositeMetrics, state.summary.evaluationCalls]); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index ba8941a24c5..f4d565bcdd7 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -77,7 +77,8 @@ import {useDeepMemo} from '../../../../../../hookUtils'; import {parseRef, WeaveObjectRef} from '../../../../../../react'; import {PREDICT_AND_SCORE_OP_NAME_POST_PYDANTIC} from '../common/heuristics'; import { - EvaluationComparisonData, + EvaluationComparisonResults, + EvaluationComparisonSummary, MetricDefinition, } from '../CompareEvaluationsPage/ecpTypes'; import { @@ -94,24 +95,25 @@ import { import {Loadable} from '../wfReactInterface/wfDataModelHooksInterface'; import {TraceCallSchema} from './traceServerClientTypes'; + /** * Primary react hook for fetching evaluation comparison data. This could be * moved into the Trace Server hooks at some point, hence the location of the file. */ -export const useEvaluationComparisonData = ( +export const useEvaluationComparisonSummary = ( entity: string, project: string, - evaluationCallIds: string[] -): Loadable => { + evaluationCallIds: string[], +): Loadable => { const getTraceServerClient = useGetTraceServerClientContext(); - const [data, setData] = useState(null); + const [data, setData] = useState(null); const evaluationCallIdsMemo = useDeepMemo(evaluationCallIds); const evaluationCallIdsRef = useRef(evaluationCallIdsMemo); useEffect(() => { setData(null); let mounted = true; - fetchEvaluationComparisonData( + fetchEvaluationSummaryData( getTraceServerClient(), entity, project, @@ -139,25 +141,69 @@ export const useEvaluationComparisonData = ( }; /** - * This function is responsible for building the data structure used to describe - * the comparison of evaluations. It is a complex function that fetches data from - * the trace server and builds a normalized data structure. + * Primary react hook for fetching evaluation comparison data. This could be + * moved into the Trace Server hooks at some point, hence the location of the file. */ -const fetchEvaluationComparisonData = async ( +export const useEvaluationComparisonResults = ( + entity: string, + project: string, + evaluationCallIds: string[], + summaryData: EvaluationComparisonSummary | null +): Loadable => { + const getTraceServerClient = useGetTraceServerClientContext(); + const [data, setData] = useState(null); + const evaluationCallIdsMemo = useDeepMemo(evaluationCallIds); + const evaluationCallIdsRef = useRef(evaluationCallIdsMemo); + + useEffect(() => { + setData(null); + let mounted = true; + if (summaryData == null) { + return; + } + fetchEvaluationComparisonResults( + getTraceServerClient(), + entity, + project, + evaluationCallIdsMemo, + summaryData + ).then(dataRes => { + if (mounted) { + evaluationCallIdsRef.current = evaluationCallIdsMemo; + setData(dataRes); + } + }); + return () => { + mounted = false; + }; + }, [entity, evaluationCallIdsMemo, project, getTraceServerClient, summaryData]); + + return useMemo(() => { + if ( + data == null || + evaluationCallIdsRef.current !== evaluationCallIdsMemo + ) { + return {loading: true, result: null}; + } + return {loading: false, result: data}; + }, [data, evaluationCallIdsMemo]); +}; + + + +const fetchEvaluationSummaryData = async ( traceServerClient: TraceServerClient, // TODO: Bad that this is leaking into user-land entity: string, project: string, evaluationCallIds: string[] -): Promise => { +): Promise => { const projectId = projectIdFromParts({entity, project}); - const result: EvaluationComparisonData = { + const result: EvaluationComparisonSummary = { entity, project, evaluationCalls: {}, evaluations: {}, - inputs: {}, models: {}, - resultRows: {}, scoreMetrics: {}, summaryMetrics: {}, }; @@ -175,36 +221,6 @@ const fetchEvaluationComparisonData = async ( filter: {call_ids: evaluationCallIds}, }); - // Kick off the trace query to get the actual trace data - // Note: we split this into 2 steps to ensure we only get level 2 children - // of the evaluations. This avoids massive overhead of fetching gigantic traces - // for every evaluation. - const evalTraceIds = evalRes.calls.map(call => call.trace_id); - // First, get all the children of the evaluations (predictAndScoreCalls + summary) - const evalTraceResProm = traceServerClient - .callsStreamQuery({ - project_id: projectId, - filter: {trace_ids: evalTraceIds, parent_ids: evaluationCallIds}, - }) - .then(predictAndScoreCallRes => { - // Then, get all the children of those calls (predictions + scores) - const predictAndScoreIds = predictAndScoreCallRes.calls.map( - call => call.id - ); - return traceServerClient - .callsStreamQuery({ - project_id: projectId, - filter: {trace_ids: evalTraceIds, parent_ids: predictAndScoreIds}, - }) - .then(predictionsAndScoresCallsRes => { - return { - calls: [ - ...predictAndScoreCallRes.calls, - ...predictionsAndScoresCallsRes.calls, - ], - }; - }); - }); const evaluationCallCache: {[callId: string]: EvaluationEvaluateCallSchema} = Object.fromEntries( @@ -220,6 +236,7 @@ const fetchEvaluationComparisonData = async ( evaluationRef: call.inputs.self, modelRef: call.inputs.model, summaryMetrics: {}, // These cannot be filled out yet since we don't know the IDs yet + traceId: call.trace_id, }, ]) ); @@ -415,9 +432,61 @@ const fetchEvaluationComparisonData = async ( }) ); + return result; +}; + +/** + * This function is responsible for building the data structure used to describe + * the comparison of evaluations. It is a complex function that fetches data from + * the trace server and builds a normalized data structure. + */ +const fetchEvaluationComparisonResults = async ( + traceServerClient: TraceServerClient, // TODO: Bad that this is leaking into user-land + entity: string, + project: string, + evaluationCallIds: string[], + summaryData: EvaluationComparisonSummary, +): Promise => { + const projectId = projectIdFromParts({entity, project}); + const result: EvaluationComparisonResults = { + inputs: {}, + resultRows: {}, + }; + + // Kick off the trace query to get the actual trace data + // Note: we split this into 2 steps to ensure we only get level 2 children + // of the evaluations. This avoids massive overhead of fetching gigantic traces + // for every evaluation. + const evalTraceIds = Object.values(summaryData.evaluationCalls).map(call => call.traceId); + // First, get all the children of the evaluations (predictAndScoreCalls + summary) + const evalTraceResProm = traceServerClient + .callsStreamQuery({ + project_id: projectId, + filter: {trace_ids: evalTraceIds, parent_ids: evaluationCallIds}, + }) + .then(predictAndScoreCallRes => { + // Then, get all the children of those calls (predictions + scores) + const predictAndScoreIds = predictAndScoreCallRes.calls.map( + call => call.id + ); + return traceServerClient + .callsStreamQuery({ + project_id: projectId, + filter: {trace_ids: evalTraceIds, parent_ids: predictAndScoreIds}, + }) + .then(predictionsAndScoresCallsRes => { + return { + calls: [ + ...predictAndScoreCallRes.calls, + ...predictionsAndScoresCallsRes.calls, + ], + }; + }); + }); + // 3.5 Populate the inputs // We only ned 1 since we are going to effectively do an inner join on the rowDigest - const datasetRef = Object.values(result.evaluations)[0].datasetRef as string; + const datasetRef = Object.values(summaryData.evaluations)[0].datasetRef as string; const datasetObjRes = await traceServerClient.readBatch({refs: [datasetRef]}); const rowsRef = datasetObjRes.vals[0].rows; const parsedRowsRef = parseRef(rowsRef) as WeaveObjectRef; @@ -440,7 +509,7 @@ const fetchEvaluationComparisonData = async ( // Create a set of all of the scorer refs const scorerRefs = new Set( - Object.values(result.evaluations).flatMap( + Object.values(summaryData.evaluations).flatMap( evaluation => evaluation.scorerRefs ) ); @@ -464,14 +533,14 @@ const fetchEvaluationComparisonData = async ( // Fill in the autosummary source calls summaryOps.forEach(summarizedOp => { const evalCallId = summarizedOp.parent_id!; - const evalCall = result.evaluationCalls[evalCallId]; + const evalCall = summaryData.evaluationCalls[evalCallId]; if (evalCall == null) { return; } Object.entries(evalCall.summaryMetrics).forEach( ([metricId, metricResult]) => { if ( - result.summaryMetrics[metricId].source === 'scorer' || + summaryData.summaryMetrics[metricId].source === 'scorer' || // Special case that the model latency is also a summary metric calc metricDefinitionId(modelLatencyMetricDimension) === metricId ) { @@ -481,6 +550,8 @@ const fetchEvaluationComparisonData = async ( ); }); + const modelRefs = Object.values(summaryData.evaluationCalls).map(evalCall => evalCall.modelRef); + // Next, we need to build the predictions object evalTraceRes.calls.forEach(traceCall => { // We are looking for 2 types of calls: @@ -599,7 +670,7 @@ const fetchEvaluationComparisonData = async ( scorerOpOrObjRef: scorerRef, }; const metricId = metricDefinitionId(metricDimension); - result.scoreMetrics[metricId] = metricDimension; + summaryData.scoreMetrics[metricId] = metricDimension; predictAndScoreFinal.scoreMetrics[metricId] = { sourceCallId: traceCall.id, value: scoreVal, @@ -612,7 +683,7 @@ const fetchEvaluationComparisonData = async ( scorerOpOrObjRef: scorerRef, }; const metricId = metricDefinitionId(metricDimension); - result.scoreMetrics[metricId] = metricDimension; + summaryData.scoreMetrics[metricId] = metricDimension; predictAndScoreFinal.scoreMetrics[metricId] = { sourceCallId: traceCall.id, @@ -648,7 +719,7 @@ const fetchEvaluationComparisonData = async ( if (isSummaryChild && isProbablyBoundScoreCall && isSummaryOp) { // Now fill in the source of the eval score const evalCallId = maybeParentSummaryOp!.parent_id!; - const evalCall = result.evaluationCalls[evalCallId]; + const evalCall = summaryData.evaluationCalls[evalCallId]; if (evalCall == null) { return; } @@ -668,7 +739,7 @@ const fetchEvaluationComparisonData = async ( Object.entries(result.resultRows).filter(([digest, row]) => { return ( Object.values(row.evaluations).length === - Object.values(result.evaluationCalls).length + Object.values(summaryData.evaluationCalls).length ); }) ); From f3839e8d48df2b2f24ea7a072c67b79c758fd31d Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 19 Dec 2024 15:05:56 -0800 Subject: [PATCH 2/5] more work --- .../CompareEvaluationsPage.tsx | 3 +- .../compositeMetricsUtil.ts | 18 ++-- .../pages/CompareEvaluationsPage/ecpState.ts | 33 +++++-- .../pages/CompareEvaluationsPage/ecpTypes.ts | 5 +- .../pages/CompareEvaluationsPage/ecpUtil.ts | 8 +- .../exampleCompareSectionUtil.ts | 10 ++- .../ExampleFilterSection.tsx | 88 ++++++++++--------- .../tsDataModelHooksEvaluationComparison.ts | 29 +++--- 8 files changed, 115 insertions(+), 79 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index a22994f5ffa..6679f3009a3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -178,9 +178,10 @@ const CompareEvaluationsPageInner: React.FC<{ height: number; }> = props => { const {state, setSelectedMetrics} = useCompareEvaluationsState(); + console.log(state); const showExampleFilter = Object.keys(state.summary.evaluationCalls).length === 2; - const showExamples = Object.keys(state.summary.resultRows).length > 0; + const showExamples = Object.keys(state.results?.resultRows ?? {}).length > 0; return ( | undefined = undefined ): CompositeScoreMetrics => { @@ -83,9 +83,9 @@ export const buildCompositeMetricsMap = ( // Get the metric definition map based on the metric type let metricDefinitionMap; if (mType === 'score') { - metricDefinitionMap = data.scoreMetrics; + metricDefinitionMap = summaryData.scoreMetrics; } else if (mType === 'summary') { - metricDefinitionMap = data.summaryMetrics; + metricDefinitionMap = summaryData.summaryMetrics; } else { throw new Error(`Invalid metric type: ${mType}`); } @@ -128,9 +128,9 @@ export const buildCompositeMetricsMap = ( }; } - const evals = Object.values(data.evaluationCalls) + const evals = Object.values(summaryData.evaluationCalls) .filter(evaluationCall => { - const evaluation = data.evaluations[evaluationCall.evaluationRef]; + const evaluation = summaryData.evaluations[evaluationCall.evaluationRef]; return ( metric.scorerOpOrObjRef == null || evaluation.scorerRefs.includes(metric.scorerOpOrObjRef) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index ffe901edf27..f7f5cca0768 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -7,9 +7,15 @@ import {useMemo} from 'react'; -import {useEvaluationComparisonResults, useEvaluationComparisonSummary} from '../wfReactInterface/tsDataModelHooksEvaluationComparison'; +import { + useEvaluationComparisonResults, + useEvaluationComparisonSummary, +} from '../wfReactInterface/tsDataModelHooksEvaluationComparison'; import {Loadable} from '../wfReactInterface/wfDataModelHooksInterface'; -import {EvaluationComparisonResults, EvaluationComparisonSummary} from './ecpTypes'; +import { + EvaluationComparisonResults, + EvaluationComparisonSummary, +} from './ecpTypes'; import {getMetricIds} from './ecpUtil'; /** @@ -52,8 +58,17 @@ export const useEvaluationComparisonState = ( const orderedCallIds = useMemo(() => { return getCallIdsOrderedForQuery(evaluationCallIds); }, [evaluationCallIds]); - const summaryData = useEvaluationComparisonSummary(entity, project, orderedCallIds); - const resultsData = useEvaluationComparisonResults(entity, project, orderedCallIds, summaryData.result); + const summaryData = useEvaluationComparisonSummary( + entity, + project, + orderedCallIds + ); + const resultsData = useEvaluationComparisonResults( + entity, + project, + orderedCallIds, + summaryData.result + ); const value = useMemo(() => { if (summaryData.result == null || summaryData.loading) { @@ -104,7 +119,15 @@ export const useEvaluationComparisonState = ( evaluationCallIdsOrdered: evaluationCallIds, }, }; - }, [summaryData.result, summaryData.loading, comparisonDimensions, resultsData.result, selectedInputDigest, selectedMetrics, evaluationCallIds]); + }, [ + summaryData.result, + summaryData.loading, + comparisonDimensions, + resultsData.result, + selectedInputDigest, + selectedMetrics, + evaluationCallIds, + ]); return value; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts index 2a505f3a16e..df40b28a92d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts @@ -7,7 +7,6 @@ */ import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes'; - export type EvaluationComparisonSummary = { // Entity and Project are constant across all calls entity: string; @@ -63,8 +62,8 @@ export type EvaluationComparisonResults = { }; }; }; - } -} + }; +}; /** * The EvaluationObj is the primary object that defines the evaluation itself. */ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts index 601116db34d..8235cdbb1c4 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts @@ -1,5 +1,5 @@ /** - * This file contains a handful of utilities for working with the `EvaluationComparisonData` destructure. + * This file contains a handful of utilities for working with the `EvaluationComparisonSummary` destructure. * These are mostly convenience functions for extracting and resolving metrics from the data, but also * include some helper functions for working with the `MetricDefinition` objects and constructing * strings correctly. @@ -8,7 +8,7 @@ import {parseRef, WeaveObjectRef} from '../../../../../../react'; import { EvaluationCall, - EvaluationComparisonData, + EvaluationComparisonSummary, MetricDefinition, MetricDefinitionMap, MetricResult, @@ -79,11 +79,11 @@ export const resolveSummaryMetricValueForEvaluateCall = ( }; export const getMetricIds = ( - data: EvaluationComparisonData, + summaryData: EvaluationComparisonSummary, type: MetricType, source: SourceType ): MetricDefinitionMap => { - const metrics = type === 'score' ? data.scoreMetrics : data.summaryMetrics; + const metrics = type === 'score' ? summaryData.scoreMetrics : summaryData.summaryMetrics; return Object.fromEntries( Object.entries(metrics).filter(([k, v]) => v.source === source) ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts index 9f4e56e9360..76f9ceff775 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts @@ -124,13 +124,13 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { const flattenedRows = useMemo(() => { const rows: FlattenedRow[] = []; - Object.entries(state.summary.resultRows).forEach( + Object.entries(state.results?.resultRows ?? {}).forEach( ([rowDigest, rowCollection]) => { Object.values(rowCollection.evaluations).forEach(modelCollection => { Object.values(modelCollection.predictAndScores).forEach( predictAndScoreRes => { const datasetRow = - state.summary.inputs[predictAndScoreRes.rowDigest]; + state.results?.inputs[predictAndScoreRes.rowDigest]; if (datasetRow != null) { const output = predictAndScoreRes._rawPredictTraceData?.output; rows.push({ @@ -169,7 +169,11 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { } ); return rows; - }, [state.summary.resultRows, state.summary.inputs, state.summary.scoreMetrics]); + }, [ + state.results?.resultRows, + state.results?.inputs, + state.summary.scoreMetrics, + ]); const pivotedRows = useMemo(() => { // Ok, so in this step we are going to pivot - diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index 47552b94c53..1c819c4ee16 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -141,51 +141,53 @@ const SingleDimensionFilter: React.FC<{ ); if (baselineTargetDimension != null && compareTargetDimension != null) { - Object.entries(props.state.summary.resultRows).forEach(([digest, row]) => { - const xVals: number[] = []; - const yVals: number[] = []; - Object.values( - row.evaluations[baselineCallId]?.predictAndScores ?? {} - ).forEach(score => { - const val = resolveScoreMetricResultForPASCall( - baselineTargetDimension, - score - ); - if (val == null) { + Object.entries(props.state.results?.resultRows ?? {}).forEach( + ([digest, row]) => { + const xVals: number[] = []; + const yVals: number[] = []; + Object.values( + row.evaluations[baselineCallId]?.predictAndScores ?? {} + ).forEach(score => { + const val = resolveScoreMetricResultForPASCall( + baselineTargetDimension, + score + ); + if (val == null) { + return; + } else if (isBinaryScore(val.value)) { + xVals.push(val.value ? 1 : 0); + } else if (isContinuousScore(val.value)) { + xVals.push(val.value); + } + }); + Object.values( + row.evaluations[compareCallId]?.predictAndScores ?? {} + ).forEach(score => { + const val = resolveScoreMetricResultForPASCall( + compareTargetDimension, + score + ); + if (val == null) { + return; + } else if (isBinaryScore(val.value)) { + yVals.push(val.value ? 1 : 0); + } else if (isContinuousScore(val.value)) { + yVals.push(val.value); + } + }); + if (xVals.length === 0 || yVals.length === 0) { return; - } else if (isBinaryScore(val.value)) { - xVals.push(val.value ? 1 : 0); - } else if (isContinuousScore(val.value)) { - xVals.push(val.value); } - }); - Object.values( - row.evaluations[compareCallId]?.predictAndScores ?? {} - ).forEach(score => { - const val = resolveScoreMetricResultForPASCall( - compareTargetDimension, - score - ); - if (val == null) { - return; - } else if (isBinaryScore(val.value)) { - yVals.push(val.value ? 1 : 0); - } else if (isContinuousScore(val.value)) { - yVals.push(val.value); - } - }); - if (xVals.length === 0 || yVals.length === 0) { - return; + series.push({ + x: mean(xVals), + y: mean(yVals), + count: xVals.length, + size: MIN_PLOT_DOT_SIZE, + color: MOON_500, + selected: filteredDigest.has(digest), + }); } - series.push({ - x: mean(xVals), - y: mean(yVals), - count: xVals.length, - size: MIN_PLOT_DOT_SIZE, - color: MOON_500, - selected: filteredDigest.has(digest), - }); - }); + ); if (targetDimension.scoreType === 'binary') { // Here we are going to further group the points by their x and y values @@ -230,7 +232,7 @@ const SingleDimensionFilter: React.FC<{ compareCallId, compositeMetricsMap, filteredDigest, - props.state.summary.resultRows, + props.state.results?.resultRows, targetDimension, ]); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index f4d565bcdd7..cb092a6c1d3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -95,7 +95,6 @@ import { import {Loadable} from '../wfReactInterface/wfDataModelHooksInterface'; import {TraceCallSchema} from './traceServerClientTypes'; - /** * Primary react hook for fetching evaluation comparison data. This could be * moved into the Trace Server hooks at some point, hence the location of the file. @@ -103,7 +102,7 @@ import {TraceCallSchema} from './traceServerClientTypes'; export const useEvaluationComparisonSummary = ( entity: string, project: string, - evaluationCallIds: string[], + evaluationCallIds: string[] ): Loadable => { const getTraceServerClient = useGetTraceServerClientContext(); const [data, setData] = useState(null); @@ -176,7 +175,13 @@ export const useEvaluationComparisonResults = ( return () => { mounted = false; }; - }, [entity, evaluationCallIdsMemo, project, getTraceServerClient, summaryData]); + }, [ + entity, + evaluationCallIdsMemo, + project, + getTraceServerClient, + summaryData, + ]); return useMemo(() => { if ( @@ -189,8 +194,6 @@ export const useEvaluationComparisonResults = ( }, [data, evaluationCallIdsMemo]); }; - - const fetchEvaluationSummaryData = async ( traceServerClient: TraceServerClient, // TODO: Bad that this is leaking into user-land entity: string, @@ -221,7 +224,6 @@ const fetchEvaluationSummaryData = async ( filter: {call_ids: evaluationCallIds}, }); - const evaluationCallCache: {[callId: string]: EvaluationEvaluateCallSchema} = Object.fromEntries( evalRes.calls.map(call => [call.id, call as EvaluationEvaluateCallSchema]) @@ -445,19 +447,21 @@ const fetchEvaluationComparisonResults = async ( entity: string, project: string, evaluationCallIds: string[], - summaryData: EvaluationComparisonSummary, + summaryData: EvaluationComparisonSummary ): Promise => { const projectId = projectIdFromParts({entity, project}); const result: EvaluationComparisonResults = { inputs: {}, resultRows: {}, }; - + // Kick off the trace query to get the actual trace data // Note: we split this into 2 steps to ensure we only get level 2 children // of the evaluations. This avoids massive overhead of fetching gigantic traces // for every evaluation. - const evalTraceIds = Object.values(summaryData.evaluationCalls).map(call => call.traceId); + const evalTraceIds = Object.values(summaryData.evaluationCalls).map( + call => call.traceId + ); // First, get all the children of the evaluations (predictAndScoreCalls + summary) const evalTraceResProm = traceServerClient .callsStreamQuery({ @@ -486,7 +490,8 @@ const fetchEvaluationComparisonResults = async ( // 3.5 Populate the inputs // We only ned 1 since we are going to effectively do an inner join on the rowDigest - const datasetRef = Object.values(summaryData.evaluations)[0].datasetRef as string; + const datasetRef = Object.values(summaryData.evaluations)[0] + .datasetRef as string; const datasetObjRes = await traceServerClient.readBatch({refs: [datasetRef]}); const rowsRef = datasetObjRes.vals[0].rows; const parsedRowsRef = parseRef(rowsRef) as WeaveObjectRef; @@ -550,7 +555,9 @@ const fetchEvaluationComparisonResults = async ( ); }); - const modelRefs = Object.values(summaryData.evaluationCalls).map(evalCall => evalCall.modelRef); + const modelRefs = Object.values(summaryData.evaluationCalls).map( + evalCall => evalCall.modelRef + ); // Next, we need to build the predictions object evalTraceRes.calls.forEach(traceCall => { From b2f1b1f0d93ec8ececae9bc92cf1cbd87ad80f45 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 19 Dec 2024 15:07:04 -0800 Subject: [PATCH 3/5] lint --- .../pages/CompareEvaluationsPage/compositeMetricsUtil.ts | 3 ++- .../Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts index a1d0679176b..7d4be894867 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compositeMetricsUtil.ts @@ -130,7 +130,8 @@ export const buildCompositeMetricsMap = ( const evals = Object.values(summaryData.evaluationCalls) .filter(evaluationCall => { - const evaluation = summaryData.evaluations[evaluationCall.evaluationRef]; + const evaluation = + summaryData.evaluations[evaluationCall.evaluationRef]; return ( metric.scorerOpOrObjRef == null || evaluation.scorerRefs.includes(metric.scorerOpOrObjRef) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts index 8235cdbb1c4..b29743dcd40 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpUtil.ts @@ -83,7 +83,8 @@ export const getMetricIds = ( type: MetricType, source: SourceType ): MetricDefinitionMap => { - const metrics = type === 'score' ? summaryData.scoreMetrics : summaryData.summaryMetrics; + const metrics = + type === 'score' ? summaryData.scoreMetrics : summaryData.summaryMetrics; return Object.fromEntries( Object.entries(metrics).filter(([k, v]) => v.source === source) ); From d18bad1933560cdd7d82c43672ef088140ae1ae7 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Thu, 19 Dec 2024 15:24:38 -0800 Subject: [PATCH 4/5] Checkpoint: fast(er) initial load --- .../CompareEvaluationsPage.tsx | 19 ++++++++++++++++--- .../pages/CompareEvaluationsPage/ecpState.ts | 6 +++--- .../exampleCompareSectionUtil.ts | 8 ++++---- .../ExampleFilterSection.tsx | 4 ++-- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index 6679f3009a3..448a0619601 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -4,6 +4,7 @@ import {Box} from '@material-ui/core'; import {Alert} from '@mui/material'; +import {WaveLoader} from '@wandb/weave/components/Loaders/WaveLoader'; import {Tailwind} from '@wandb/weave/components/Tailwind'; import {maybePluralizeWord} from '@wandb/weave/core/util/string'; import React, {FC, useCallback, useContext, useMemo, useState} from 'react'; @@ -178,10 +179,11 @@ const CompareEvaluationsPageInner: React.FC<{ height: number; }> = props => { const {state, setSelectedMetrics} = useCompareEvaluationsState(); - console.log(state); const showExampleFilter = Object.keys(state.summary.evaluationCalls).length === 2; - const showExamples = Object.keys(state.results?.resultRows ?? {}).length > 0; + const showExamples = + Object.keys(state.results.result?.resultRows ?? {}).length > 0; + const resultsLoading = state.results.loading; return ( - {showExamples ? ( + {resultsLoading ? ( + + + + ) : showExamples ? ( <> {showExampleFilter && } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index f7f5cca0768..246876cb444 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -25,7 +25,7 @@ export type EvaluationComparisonState = { // The normalized data for the evaluations summary: EvaluationComparisonSummary; // The results of the evaluations - results: EvaluationComparisonResults | null; + results: Loadable; // The dimensions to compare & filter results comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view @@ -112,7 +112,7 @@ export const useEvaluationComparisonState = ( loading: false, result: { summary: summaryData.result, - results: resultsData.result, + results: resultsData, comparisonDimensions: newComparisonDimensions, selectedInputDigest, selectedMetrics, @@ -123,7 +123,7 @@ export const useEvaluationComparisonState = ( summaryData.result, summaryData.loading, comparisonDimensions, - resultsData.result, + resultsData, selectedInputDigest, selectedMetrics, evaluationCallIds, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts index 76f9ceff775..0cfec0b6f69 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/exampleCompareSectionUtil.ts @@ -124,13 +124,13 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { const flattenedRows = useMemo(() => { const rows: FlattenedRow[] = []; - Object.entries(state.results?.resultRows ?? {}).forEach( + Object.entries(state.results.result?.resultRows ?? {}).forEach( ([rowDigest, rowCollection]) => { Object.values(rowCollection.evaluations).forEach(modelCollection => { Object.values(modelCollection.predictAndScores).forEach( predictAndScoreRes => { const datasetRow = - state.results?.inputs[predictAndScoreRes.rowDigest]; + state.results.result?.inputs[predictAndScoreRes.rowDigest]; if (datasetRow != null) { const output = predictAndScoreRes._rawPredictTraceData?.output; rows.push({ @@ -170,8 +170,8 @@ export const useFilteredAggregateRows = (state: EvaluationComparisonState) => { ); return rows; }, [ - state.results?.resultRows, - state.results?.inputs, + state.results.result?.resultRows, + state.results.result?.inputs, state.summary.scoreMetrics, ]); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index 1c819c4ee16..13d590c35da 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -141,7 +141,7 @@ const SingleDimensionFilter: React.FC<{ ); if (baselineTargetDimension != null && compareTargetDimension != null) { - Object.entries(props.state.results?.resultRows ?? {}).forEach( + Object.entries(props.state.results.result?.resultRows ?? {}).forEach( ([digest, row]) => { const xVals: number[] = []; const yVals: number[] = []; @@ -232,7 +232,7 @@ const SingleDimensionFilter: React.FC<{ compareCallId, compositeMetricsMap, filteredDigest, - props.state.results?.resultRows, + props.state.results.result?.resultRows, targetDimension, ]); From f9ac9d85b989f77209d3fe38dde8b4659c9ab72b Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 20 Dec 2024 20:17:22 -0800 Subject: [PATCH 5/5] lint --- .../tsDataModelHooksEvaluationComparison.ts | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts index cb092a6c1d3..c321accd487 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/tsDataModelHooksEvaluationComparison.ts @@ -70,6 +70,7 @@ */ import {sum} from 'lodash'; +import _ from 'lodash'; import {useEffect, useMemo, useRef, useState} from 'react'; import {WB_RUN_COLORS} from '../../../../../../common/css/color.styles'; @@ -473,19 +474,26 @@ const fetchEvaluationComparisonResults = async ( const predictAndScoreIds = predictAndScoreCallRes.calls.map( call => call.id ); - return traceServerClient - .callsStreamQuery({ - project_id: projectId, - filter: {trace_ids: evalTraceIds, parent_ids: predictAndScoreIds}, + + return Promise.all( + _.chunk(predictAndScoreIds, 500).map(chunk => { + return traceServerClient + .callsStreamQuery({ + project_id: projectId, + filter: {trace_ids: evalTraceIds, parent_ids: chunk}, + }) + .then(predictionsAndScoresCallsRes => { + return predictionsAndScoresCallsRes.calls; + }); }) - .then(predictionsAndScoresCallsRes => { - return { - calls: [ - ...predictAndScoreCallRes.calls, - ...predictionsAndScoresCallsRes.calls, - ], - }; - }); + ).then(predictionsAndScoresCallsResMany => { + return { + calls: [ + ...predictAndScoreCallRes.calls, + ...predictionsAndScoresCallsResMany.flat(), + ], + }; + }); }); // 3.5 Populate the inputs