Skip to content

Commit a17627a

Browse files
authored
Merge pull request #3 from unicef/nicpottier/image-classification
Add image classification, metadata extraction, and LLM pipeline
2 parents ce5d958 + c4ad002 commit a17627a

28 files changed

Lines changed: 826 additions & 119 deletions

config.yaml

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,10 @@ text_group_types:
2727
list: A list of items (ordered or unordered)
2828
other: Anything that doesn't fit the above
2929

30+
metadata:
31+
prompt: metadata_extraction
32+
model: openai:gpt-4o
33+
3034
text_classification:
3135
prompt: text_classification
3236
model: openai:gpt-4o
@@ -36,3 +40,7 @@ pruned_text_types:
3640
- header_text
3741
- footer_text
3842
- page_number
43+
44+
image_filters:
45+
min_side: 100
46+
max_side: 5000

packages/llm/src/client.ts

Lines changed: 32 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,14 @@ import type {
77
Message,
88
TokenUsage,
99
} from "./types.js"
10+
import type { PromptEngine } from "./prompt.js"
1011
import { computeHash, readCache, writeCache, bustCache } from "./cache.js"
1112
import { sanitizeMessages, type LlmLogEntry } from "./log.js"
1213

1314
export interface CreateLLMModelOptions {
1415
modelId: string // "openai:gpt-4o" format
1516
cacheDir?: string
17+
promptEngine?: PromptEngine
1618
onLog?: (entry: LlmLogEntry) => void
1719
}
1820

@@ -23,27 +25,48 @@ export interface CreateLLMModelOptions {
2325
* - Disk-based response caching (SHA-256 hash of inputs)
2426
* - Validation with retry loops
2527
* - Structured logging (images replaced with hash placeholders)
28+
* - Optional prompt rendering (pass promptEngine + use prompt option)
2629
*/
2730
export function createLLMModel(options: CreateLLMModelOptions): LLMModel {
28-
const { modelId, cacheDir, onLog } = options
31+
const { modelId, cacheDir, promptEngine, onLog } = options
2932
const languageModel = resolveModel(modelId)
3033

3134
return {
3235
async generateObject<T>(
3336
opts: GenerateObjectOptions
3437
): Promise<GenerateObjectResult<T>> {
38+
// Resolve prompt to system + messages if needed
39+
let system = opts.system
40+
let messages = opts.messages ?? []
41+
42+
if (opts.prompt) {
43+
if (!promptEngine) {
44+
throw new Error("promptEngine required when using prompt option")
45+
}
46+
const allMessages = await promptEngine.renderPrompt(
47+
opts.prompt.name,
48+
opts.prompt.context
49+
)
50+
const systemMsg = allMessages.find((m) => m.role === "system")
51+
system =
52+
typeof systemMsg?.content === "string"
53+
? systemMsg.content
54+
: undefined
55+
messages = allMessages.filter((m) => m.role !== "system")
56+
}
57+
3558
const maxRetries = opts.maxRetries ?? 0
3659
const t0 = Date.now()
3760

38-
let currentMessages = opts.messages
61+
let currentMessages = messages
3962
let allErrors: string[] = []
4063
let lastCacheHit = false
4164
let totalUsage: TokenUsage = { inputTokens: 0, outputTokens: 0 }
4265

4366
for (let attempt = 0; attempt <= maxRetries; attempt++) {
4467
const hash = computeHash({
4568
modelId,
46-
system: opts.system,
69+
system,
4770
messages: currentMessages,
4871
schema: opts.schema,
4972
})
@@ -61,6 +84,7 @@ export function createLLMModel(options: CreateLLMModelOptions): LLMModel {
6184
const generated = await callLLM<T>(
6285
languageModel,
6386
opts,
87+
system,
6488
currentMessages
6589
)
6690
result = generated.object
@@ -73,6 +97,7 @@ export function createLLMModel(options: CreateLLMModelOptions): LLMModel {
7397
const generated = await callLLM<T>(
7498
languageModel,
7599
opts,
100+
system,
76101
currentMessages
77102
)
78103
result = generated.object
@@ -112,7 +137,7 @@ export function createLLMModel(options: CreateLLMModelOptions): LLMModel {
112137
? totalUsage
113138
: undefined,
114139
validationErrors: allErrors.length > 0 ? allErrors : undefined,
115-
system: opts.system,
140+
system,
116141
messages: sanitizeMessages(currentMessages),
117142
})
118143
}
@@ -143,7 +168,7 @@ export function createLLMModel(options: CreateLLMModelOptions): LLMModel {
143168
? totalUsage
144169
: undefined,
145170
validationErrors: allErrors,
146-
system: opts.system,
171+
system,
147172
messages: sanitizeMessages(currentMessages),
148173
})
149174
}
@@ -175,13 +200,14 @@ function resolveModel(modelId: string): LanguageModel {
175200
async function callLLM<T>(
176201
model: LanguageModel,
177202
opts: GenerateObjectOptions,
203+
system: string | undefined,
178204
messages: Message[]
179205
): Promise<{ object: T; usage: TokenUsage }> {
180206
const coreMessages = convertMessages(messages)
181207
const generateOpts: Record<string, unknown> = {
182208
model,
183209
schema: opts.schema,
184-
system: opts.system,
210+
system,
185211
messages: coreMessages,
186212
}
187213
if (opts.maxTokens) {

packages/llm/src/types.ts

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,12 @@ export interface LLMModel {
44

55
export interface GenerateObjectOptions {
66
schema: unknown
7+
8+
/** Provide either prompt (rendered via prompt engine) or system + messages directly */
9+
prompt?: { name: string; context: Record<string, unknown> }
710
system?: string
8-
messages: Message[]
11+
messages?: Message[]
12+
913
validate?: (result: unknown) => ValidationResult
1014
maxRetries?: number
1115
maxTokens?: number
Lines changed: 194 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,194 @@
1+
import { describe, it, expect } from "vitest"
2+
import { classifyPageImages, buildImageClassifyConfig } from "../image-classification.js"
3+
import type { ImageClassifyConfig } from "../image-classification.js"
4+
import type { ImageData } from "@adt/storage"
5+
import type { AppConfig } from "@adt/types"
6+
7+
function makeImage(imageId: string, width: number, height: number): ImageData {
8+
return { imageId, width, height }
9+
}
10+
11+
describe("classifyPageImages", () => {
12+
const defaultConfig: ImageClassifyConfig = {
13+
filters: { min_side: 100, max_side: 5000 },
14+
}
15+
16+
it("prunes full-page renders", () => {
17+
const images = [makeImage("pg001_page", 800, 600)]
18+
const result = classifyPageImages("pg001", images, defaultConfig)
19+
20+
expect(result.images).toHaveLength(1)
21+
expect(result.images[0]).toEqual({
22+
imageId: "pg001_page",
23+
isPruned: true,
24+
reason: "full-page render",
25+
})
26+
})
27+
28+
it("keeps images within size bounds", () => {
29+
const images = [makeImage("pg001_im001", 400, 300)]
30+
const result = classifyPageImages("pg001", images, defaultConfig)
31+
32+
expect(result.images).toHaveLength(1)
33+
expect(result.images[0]).toEqual({
34+
imageId: "pg001_im001",
35+
isPruned: false,
36+
})
37+
})
38+
39+
it("prunes images with shortest side below min_side", () => {
40+
const images = [makeImage("pg001_im001", 200, 50)]
41+
const result = classifyPageImages("pg001", images, defaultConfig)
42+
43+
expect(result.images).toHaveLength(1)
44+
expect(result.images[0]).toEqual({
45+
imageId: "pg001_im001",
46+
isPruned: true,
47+
reason: "shortest side 50px < min_side 100px",
48+
})
49+
})
50+
51+
it("prunes images with longest side above max_side", () => {
52+
const images = [makeImage("pg001_im001", 6000, 3000)]
53+
const result = classifyPageImages("pg001", images, defaultConfig)
54+
55+
expect(result.images).toHaveLength(1)
56+
expect(result.images[0]).toEqual({
57+
imageId: "pg001_im001",
58+
isPruned: true,
59+
reason: "longest side 6000px > max_side 5000px",
60+
})
61+
})
62+
63+
it("checks min_side before max_side", () => {
64+
// Image with both short side < min and long side > max: min_side triggers first
65+
const images = [makeImage("pg001_im001", 6000, 50)]
66+
const result = classifyPageImages("pg001", images, defaultConfig)
67+
68+
expect(result.images[0].reason).toContain("min_side")
69+
})
70+
71+
it("handles empty image list", () => {
72+
const result = classifyPageImages("pg001", [], defaultConfig)
73+
expect(result.images).toHaveLength(0)
74+
})
75+
76+
it("classifies multiple images per page", () => {
77+
const images = [
78+
makeImage("pg001_page", 800, 600),
79+
makeImage("pg001_im001", 400, 300),
80+
makeImage("pg001_im002", 20, 20),
81+
makeImage("pg001_im003", 6000, 4000),
82+
]
83+
const result = classifyPageImages("pg001", images, defaultConfig)
84+
85+
expect(result.images).toHaveLength(4)
86+
expect(result.images[0].isPruned).toBe(true) // full-page render
87+
expect(result.images[1].isPruned).toBe(false) // good size
88+
expect(result.images[2].isPruned).toBe(true) // too small
89+
expect(result.images[3].isPruned).toBe(true) // too big
90+
})
91+
92+
it("skips min_side check when not configured", () => {
93+
const config: ImageClassifyConfig = { filters: { max_side: 5000 } }
94+
const images = [makeImage("pg001_im001", 10, 5)]
95+
const result = classifyPageImages("pg001", images, config)
96+
97+
expect(result.images[0].isPruned).toBe(false)
98+
})
99+
100+
it("skips max_side check when not configured", () => {
101+
const config: ImageClassifyConfig = { filters: { min_side: 100 } }
102+
const images = [makeImage("pg001_im001", 10000, 8000)]
103+
const result = classifyPageImages("pg001", images, config)
104+
105+
expect(result.images[0].isPruned).toBe(false)
106+
})
107+
108+
it("keeps all non-page images when no filters configured", () => {
109+
const config: ImageClassifyConfig = { filters: {} }
110+
const images = [
111+
makeImage("pg001_page", 800, 600),
112+
makeImage("pg001_im001", 10, 5),
113+
makeImage("pg001_im002", 10000, 8000),
114+
]
115+
const result = classifyPageImages("pg001", images, config)
116+
117+
expect(result.images[0].isPruned).toBe(true) // page render always pruned
118+
expect(result.images[1].isPruned).toBe(false)
119+
expect(result.images[2].isPruned).toBe(false)
120+
})
121+
122+
it("uses min of width/height for min_side check", () => {
123+
const config: ImageClassifyConfig = { filters: { min_side: 100 } }
124+
125+
// Portrait: width=80, height=200 → shortSide=80 < 100 → pruned
126+
const portrait = classifyPageImages("pg001", [makeImage("pg001_im001", 80, 200)], config)
127+
expect(portrait.images[0].isPruned).toBe(true)
128+
129+
// Landscape: width=200, height=80 → shortSide=80 < 100 → pruned
130+
const landscape = classifyPageImages("pg001", [makeImage("pg001_im001", 200, 80)], config)
131+
expect(landscape.images[0].isPruned).toBe(true)
132+
133+
// Both sides above: width=150, height=120 → shortSide=120 ≥ 100 → kept
134+
const ok = classifyPageImages("pg001", [makeImage("pg001_im001", 150, 120)], config)
135+
expect(ok.images[0].isPruned).toBe(false)
136+
})
137+
138+
it("uses max of width/height for max_side check", () => {
139+
const config: ImageClassifyConfig = { filters: { max_side: 5000 } }
140+
141+
// Portrait: width=3000, height=6000 → longSide=6000 > 5000 → pruned
142+
const portrait = classifyPageImages("pg001", [makeImage("pg001_im001", 3000, 6000)], config)
143+
expect(portrait.images[0].isPruned).toBe(true)
144+
145+
// Landscape: width=6000, height=3000 → longSide=6000 > 5000 → pruned
146+
const landscape = classifyPageImages("pg001", [makeImage("pg001_im001", 6000, 3000)], config)
147+
expect(landscape.images[0].isPruned).toBe(true)
148+
149+
// Both sides below: width=4000, height=3000 → longSide=4000 ≤ 5000 → kept
150+
const ok = classifyPageImages("pg001", [makeImage("pg001_im001", 4000, 3000)], config)
151+
expect(ok.images[0].isPruned).toBe(false)
152+
})
153+
154+
it("handles edge case: image exactly at min_side boundary", () => {
155+
const config: ImageClassifyConfig = { filters: { min_side: 100 } }
156+
const images = [makeImage("pg001_im001", 100, 200)]
157+
const result = classifyPageImages("pg001", images, config)
158+
159+
// Exactly 100 is not < 100, so kept
160+
expect(result.images[0].isPruned).toBe(false)
161+
})
162+
163+
it("handles edge case: image exactly at max_side boundary", () => {
164+
const config: ImageClassifyConfig = { filters: { max_side: 5000 } }
165+
const images = [makeImage("pg001_im001", 5000, 3000)]
166+
const result = classifyPageImages("pg001", images, config)
167+
168+
// Exactly 5000 is not > 5000, so kept
169+
expect(result.images[0].isPruned).toBe(false)
170+
})
171+
})
172+
173+
describe("buildImageClassifyConfig", () => {
174+
it("extracts image_filters from AppConfig", () => {
175+
const appConfig: AppConfig = {
176+
text_types: { heading: "Heading" },
177+
text_group_types: { paragraph: "Paragraph" },
178+
image_filters: { min_side: 50, max_side: 3000 },
179+
}
180+
181+
const config = buildImageClassifyConfig(appConfig)
182+
expect(config.filters).toEqual({ min_side: 50, max_side: 3000 })
183+
})
184+
185+
it("defaults to empty filters when image_filters not set", () => {
186+
const appConfig: AppConfig = {
187+
text_types: { heading: "Heading" },
188+
text_group_types: { paragraph: "Paragraph" },
189+
}
190+
191+
const config = buildImageClassifyConfig(appConfig)
192+
expect(config.filters).toEqual({})
193+
})
194+
})

0 commit comments

Comments
 (0)