From 810f4a6bf2e6785fb167d369197321ede223bfb2 Mon Sep 17 00:00:00 2001 From: Shinwoo PARK Date: Mon, 11 May 2026 12:05:47 +0900 Subject: [PATCH] feat: add semantic fact search and embeddings --- src/core/identity-db.ts | 226 +++++++++++++++++++++++++++++++++- src/core/migrations.ts | 23 ++++ src/core/schema.ts | 11 ++ src/core/utils.ts | 38 +++++- src/ingestion/types.ts | 8 +- src/types/api.ts | 31 +++++ src/types/database.ts | 2 + src/types/domain.ts | 10 ++ tests/migrations.test.ts | 14 ++- tests/semantic-search.test.ts | 170 +++++++++++++++++++++++++ 10 files changed, 529 insertions(+), 4 deletions(-) create mode 100644 tests/semantic-search.test.ts diff --git a/src/core/identity-db.ts b/src/core/identity-db.ts index 6179daa..c802a19 100644 --- a/src/core/identity-db.ts +++ b/src/core/identity-db.ts @@ -2,7 +2,11 @@ import { type ConnectedTopic, type Fact, type FactTopic, + type FindSimilarFactsInput, + type IndexFactEmbeddingsInput, type ListTopicsOptions, + type ScoredFact, + type SearchFactsInput, type Topic, type TopicLookupOptions, type TopicWithFacts, @@ -19,11 +23,15 @@ import { IdentityDBError } from './errors'; import { initializeSchema } from './migrations'; import { canonicalizeTopicName, + cosineSimilarity, + createContentHash, createId, + deserializeEmbedding, mapFactRow, mapTopicRow, normalizeTopicName, nowIsoString, + serializeEmbedding, serializeMetadata, } from './utils'; import { extractFact } from '../ingestion/extractor'; @@ -155,7 +163,107 @@ export class IdentityDB { factInput.metadata = extracted.metadata; } - return this.addFact(factInput); + if (options.embeddingProvider) { + const similarFacts = await this.findSimilarFacts({ + statement: factInput.statement, + provider: options.embeddingProvider, + topicNames: factInput.topics.map((topic) => topic.name), + limit: 1, + minimumScore: options.duplicateThreshold ?? 0.97, + }); + + if (similarFacts[0]) { + return similarFacts[0]; + } + } + + const fact = await this.addFact(factInput); + + if (options.embeddingProvider) { + await this.indexFactEmbedding(fact.id, { provider: options.embeddingProvider }); + } + + return fact; + } + + async indexFactEmbeddings(input: IndexFactEmbeddingsInput): Promise { + const factRows = await this.connection.db.selectFrom('facts').selectAll().orderBy('created_at', 'asc').execute(); + + if (factRows.length === 0) { + return; + } + + const embeddings = input.provider.embedMany + ? await input.provider.embedMany(factRows.map((factRow) => factRow.statement)) + : await Promise.all(factRows.map((factRow) => input.provider.embed(factRow.statement))); + + if (embeddings.length !== factRows.length) { + throw new IdentityDBError('Embedding provider returned a mismatched number of embeddings.'); + } + + await this.connection.db.transaction().execute(async (trx) => { + for (let index = 0; index < factRows.length; index += 1) { + const factRow = factRows[index]!; + const embedding = embeddings[index]!; + this.assertEmbeddingShape(embedding, input.provider.dimensions); + await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model); + } + }); + } + + async indexFactEmbedding(factId: string, input: IndexFactEmbeddingsInput): Promise { + const factRow = await this.connection.db + .selectFrom('facts') + .selectAll() + .where('id', '=', factId) + .executeTakeFirst(); + + if (!factRow) { + throw new IdentityDBError(`Fact not found: ${factId}`); + } + + const embedding = await input.provider.embed(factRow.statement); + this.assertEmbeddingShape(embedding, input.provider.dimensions); + + await this.connection.db.transaction().execute(async (trx) => { + await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model); + }); + } + + async searchFacts(input: SearchFactsInput): Promise { + const queryText = input.query.trim(); + if (queryText.length === 0) { + return []; + } + + const queryEmbedding = await input.provider.embed(queryText); + this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions); + + return this.searchFactsByEmbedding({ + providerModel: input.provider.model, + queryEmbedding, + topicNames: input.topicNames, + limit: input.limit, + minimumScore: input.minimumScore, + }); + } + + async findSimilarFacts(input: FindSimilarFactsInput): Promise { + const statement = input.statement.trim(); + if (statement.length === 0) { + return []; + } + + const queryEmbedding = await input.provider.embed(statement); + this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions); + + return this.searchFactsByEmbedding({ + providerModel: input.provider.model, + queryEmbedding, + topicNames: input.topicNames, + limit: input.limit, + minimumScore: input.minimumScore, + }); } async linkTopics(input: LinkTopicsInput): Promise { @@ -413,6 +521,122 @@ export class IdentityDB { })); } + private async searchFactsByEmbedding(input: { + providerModel: string; + queryEmbedding: number[]; + topicNames?: string[] | undefined; + limit?: number | undefined; + minimumScore?: number | undefined; + }): Promise { + const topicIds = await this.resolveTopicIds(input.topicNames); + if (topicIds === null) { + return []; + } + + const factRows = topicIds.length > 0 + ? await findFactRowsConnectingTopicIds(this.connection.db, topicIds) + : await this.connection.db + .selectFrom('facts') + .innerJoin('fact_embeddings', 'fact_embeddings.fact_id', 'facts.id') + .selectAll('facts') + .where('fact_embeddings.model', '=', input.providerModel) + .orderBy('facts.created_at', 'asc') + .execute(); + + if (factRows.length === 0) { + return []; + } + + const embeddingRowsQuery = this.connection.db + .selectFrom('fact_embeddings') + .selectAll() + .where('model', '=', input.providerModel); + + const embeddingRows = factRows.length > 0 + ? await embeddingRowsQuery.where('fact_id', 'in', factRows.map((factRow) => factRow.id)).execute() + : []; + + const embeddingsByFactId = new Map( + embeddingRows.map((embeddingRow) => [embeddingRow.fact_id, deserializeEmbedding(embeddingRow.embedding)]), + ); + + const scoredRows = factRows + .map((factRow) => ({ + factRow, + score: cosineSimilarity(input.queryEmbedding, embeddingsByFactId.get(factRow.id) ?? []), + })) + .filter((entry) => entry.score >= (input.minimumScore ?? 0)) + .sort((left, right) => { + if (right.score !== left.score) { + return right.score - left.score; + } + return left.factRow.created_at.localeCompare(right.factRow.created_at); + }) + .slice(0, input.limit ?? 5); + + if (scoredRows.length === 0) { + return []; + } + + const hydratedFacts = await this.hydrateFacts(scoredRows.map((entry) => entry.factRow)); + const factsById = new Map(hydratedFacts.map((fact) => [fact.id, fact])); + + return scoredRows.map((entry) => ({ + ...factsById.get(entry.factRow.id)!, + score: entry.score, + })); + } + + private async resolveTopicIds(topicNames?: string[]): Promise { + if (!topicNames || topicNames.length === 0) { + return []; + } + + const topicRows = await Promise.all(topicNames.map((topicName) => this.getRequiredTopicRow(topicName))); + if (topicRows.some((topicRow) => !topicRow)) { + return null; + } + + return topicRows.map((topicRow) => topicRow!.id); + } + + private async upsertFactEmbeddingRecord( + executor: DatabaseExecutor, + factId: string, + statement: string, + embedding: number[], + model: string, + ): Promise { + const timestamp = nowIsoString(); + + await executor + .deleteFrom('fact_embeddings') + .where('fact_id', '=', factId) + .where('model', '=', model) + .execute(); + + await executor + .insertInto('fact_embeddings') + .values({ + fact_id: factId, + model, + dimensions: embedding.length, + embedding: serializeEmbedding(embedding), + content_hash: createContentHash(statement), + created_at: timestamp, + updated_at: timestamp, + }) + .execute(); + } + + private assertEmbeddingShape(embedding: number[], expectedDimensions: number): void { + if (embedding.length !== expectedDimensions) { + throw new IdentityDBError( + `Embedding dimension mismatch. Expected ${expectedDimensions}, received ${embedding.length}.`, + ); + } + } + private async upsertTopicInExecutor( executor: DatabaseExecutor, input: UpsertTopicInput, diff --git a/src/core/migrations.ts b/src/core/migrations.ts index 9c6b95a..2bc2398 100644 --- a/src/core/migrations.ts +++ b/src/core/migrations.ts @@ -2,6 +2,7 @@ import type { Kysely } from 'kysely'; import { FACTS_TABLE, + FACT_EMBEDDINGS_TABLE, FACT_TOPICS_TABLE, TOPIC_ALIASES_TABLE, TOPIC_RELATIONS_TABLE, @@ -39,6 +40,21 @@ export async function initializeSchema( .addColumn('updated_at', 'text', (column) => column.notNull()) .execute(); + await db.schema + .createTable(FACT_EMBEDDINGS_TABLE) + .ifNotExists() + .addColumn('fact_id', 'text', (column) => + column.notNull().references(`${FACTS_TABLE}.id`).onDelete('cascade'), + ) + .addColumn('model', 'text', (column) => column.notNull()) + .addColumn('dimensions', 'integer', (column) => column.notNull()) + .addColumn('embedding', 'text', (column) => column.notNull()) + .addColumn('content_hash', 'text', (column) => column.notNull()) + .addColumn('created_at', 'text', (column) => column.notNull()) + .addColumn('updated_at', 'text', (column) => column.notNull()) + .addPrimaryKeyConstraint('fact_embeddings_pk', ['fact_id', 'model']) + .execute(); + await db.schema .createTable(FACT_TOPICS_TABLE) .ifNotExists() @@ -96,6 +112,13 @@ export async function initializeSchema( .column('fact_id') .execute(); + await db.schema + .createIndex('fact_embeddings_model_idx') + .ifNotExists() + .on(FACT_EMBEDDINGS_TABLE) + .column('model') + .execute(); + await db.schema .createIndex('topic_relations_parent_topic_id_idx') .ifNotExists() diff --git a/src/core/schema.ts b/src/core/schema.ts index 0200234..5bbd72d 100644 --- a/src/core/schema.ts +++ b/src/core/schema.ts @@ -3,6 +3,7 @@ export const FACTS_TABLE = 'facts'; export const FACT_TOPICS_TABLE = 'fact_topics'; export const TOPIC_RELATIONS_TABLE = 'topic_relations'; export const TOPIC_ALIASES_TABLE = 'topic_aliases'; +export const FACT_EMBEDDINGS_TABLE = 'fact_embeddings'; export const TOPIC_COLUMNS = [ 'id', @@ -51,3 +52,13 @@ export const TOPIC_ALIAS_COLUMNS = [ 'created_at', 'updated_at', ] as const; + +export const FACT_EMBEDDING_COLUMNS = [ + 'fact_id', + 'model', + 'dimensions', + 'embedding', + 'content_hash', + 'created_at', + 'updated_at', +] as const; diff --git a/src/core/utils.ts b/src/core/utils.ts index 022f139..5cbc795 100644 --- a/src/core/utils.ts +++ b/src/core/utils.ts @@ -1,4 +1,4 @@ -import { randomUUID } from 'node:crypto'; +import { createHash, randomUUID } from 'node:crypto'; import type { Fact, FactTopic, Topic } from '../types/api'; import type { FactRecord, TopicRecord } from '../types/domain'; @@ -35,6 +35,42 @@ export function deserializeMetadata(metadata: string | null): unknown | null { return JSON.parse(metadata); } +export function serializeEmbedding(embedding: number[]): string { + return JSON.stringify(embedding); +} + +export function deserializeEmbedding(embedding: string): number[] { + return JSON.parse(embedding) as number[]; +} + +export function createContentHash(input: string): string { + return createHash('sha256').update(input).digest('hex'); +} + +export function cosineSimilarity(left: number[], right: number[]): number { + if (left.length === 0 || left.length !== right.length) { + return 0; + } + + let dot = 0; + let leftMagnitude = 0; + let rightMagnitude = 0; + + for (let index = 0; index < left.length; index += 1) { + const leftValue = left[index] ?? 0; + const rightValue = right[index] ?? 0; + dot += leftValue * rightValue; + leftMagnitude += leftValue * leftValue; + rightMagnitude += rightValue * rightValue; + } + + if (leftMagnitude === 0 || rightMagnitude === 0) { + return 0; + } + + return dot / (Math.sqrt(leftMagnitude) * Math.sqrt(rightMagnitude)); +} + export function mapTopicRow(record: TopicRecord): Topic { return { id: record.id, diff --git a/src/ingestion/types.ts b/src/ingestion/types.ts index d1add29..11aca5d 100644 --- a/src/ingestion/types.ts +++ b/src/ingestion/types.ts @@ -1,4 +1,8 @@ -import type { AddFactInput, TopicLinkInput } from '../types/api'; +import type { + AddFactInput, + EmbeddingProvider, + TopicLinkInput, +} from '../types/api'; export interface ExtractedFact { statement?: string; @@ -15,4 +19,6 @@ export interface FactExtractor { export interface IngestStatementOptions { extractor: FactExtractor; + embeddingProvider?: EmbeddingProvider; + duplicateThreshold?: number; } diff --git a/src/types/api.ts b/src/types/api.ts index 8ed5329..20cee77 100644 --- a/src/types/api.ts +++ b/src/types/api.ts @@ -71,3 +71,34 @@ export interface ListTopicsOptions { includeFacts?: boolean; limit?: number; } + +export interface EmbeddingProvider { + model: string; + dimensions: number; + embed(input: string): Promise; + embedMany?(inputs: string[]): Promise; +} + +export interface IndexFactEmbeddingsInput { + provider: EmbeddingProvider; +} + +export interface SearchFactsInput { + query: string; + provider: EmbeddingProvider; + topicNames?: string[]; + limit?: number; + minimumScore?: number; +} + +export interface FindSimilarFactsInput { + statement: string; + provider: EmbeddingProvider; + topicNames?: string[]; + limit?: number; + minimumScore?: number; +} + +export interface ScoredFact extends Fact { + score: number; +} diff --git a/src/types/database.ts b/src/types/database.ts index c9feeeb..8c9e8ae 100644 --- a/src/types/database.ts +++ b/src/types/database.ts @@ -1,4 +1,5 @@ import type { + FactEmbeddingRecord, FactRecord, FactTopicRecord, TopicAliasRecord, @@ -12,4 +13,5 @@ export interface IdentityDatabaseSchema { fact_topics: FactTopicRecord; topic_relations: TopicRelationRecord; topic_aliases: TopicAliasRecord; + fact_embeddings: FactEmbeddingRecord; } diff --git a/src/types/domain.ts b/src/types/domain.ts index 4be2cdb..ad94f4b 100644 --- a/src/types/domain.ts +++ b/src/types/domain.ts @@ -52,3 +52,13 @@ export interface TopicAliasRecord { created_at: string; updated_at: string; } + +export interface FactEmbeddingRecord { + fact_id: string; + model: string; + dimensions: number; + embedding: string; + content_hash: string; + created_at: string; + updated_at: string; +} diff --git a/tests/migrations.test.ts b/tests/migrations.test.ts index e119ef1..20b072e 100644 --- a/tests/migrations.test.ts +++ b/tests/migrations.test.ts @@ -16,7 +16,7 @@ afterEach(async () => { }); describe('initializeSchema', () => { - it('creates the topics, facts, fact_topics, topic_relations, and topic_aliases tables', async () => { + it('creates the topics, facts, fact_embeddings, fact_topics, topic_relations, and topic_aliases tables', async () => { const connection = await createDatabase({ client: 'sqlite', filename: ':memory:' }); openConnections.push(connection.destroy); @@ -33,6 +33,7 @@ describe('initializeSchema', () => { expect(tableNames).toContain('topics'); expect(tableNames).toContain('facts'); + expect(tableNames).toContain('fact_embeddings'); expect(tableNames).toContain('fact_topics'); expect(tableNames).toContain('topic_relations'); expect(tableNames).toContain('topic_aliases'); @@ -46,6 +47,7 @@ describe('initializeSchema', () => { const topicsColumns = await sql<{ name: string }>`PRAGMA table_info(topics)`.execute(connection.db); const factsColumns = await sql<{ name: string }>`PRAGMA table_info(facts)`.execute(connection.db); + const factEmbeddingsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_embeddings)`.execute(connection.db); const factTopicsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_topics)`.execute(connection.db); const topicRelationsColumns = await sql<{ name: string }>`PRAGMA table_info(topic_relations)`.execute(connection.db); const topicAliasesColumns = await sql<{ name: string }>`PRAGMA table_info(topic_aliases)`.execute(connection.db); @@ -73,6 +75,16 @@ describe('initializeSchema', () => { 'updated_at', ]); + expect(factEmbeddingsColumns.rows.map((row) => row.name)).toEqual([ + 'fact_id', + 'model', + 'dimensions', + 'embedding', + 'content_hash', + 'created_at', + 'updated_at', + ]); + expect(factTopicsColumns.rows.map((row) => row.name)).toEqual([ 'fact_id', 'topic_id', diff --git a/tests/semantic-search.test.ts b/tests/semantic-search.test.ts new file mode 100644 index 0000000..131324d --- /dev/null +++ b/tests/semantic-search.test.ts @@ -0,0 +1,170 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { IdentityDB } from '../src/core/identity-db'; +import type { FactExtractor } from '../src/ingestion/types'; +import type { EmbeddingProvider } from '../src/types/api'; + +class FakeEmbeddingProvider implements EmbeddingProvider { + model = 'fake-semantic-v1'; + dimensions = 3; + + async embed(input: string): Promise { + return embeddingFor(input); + } + + async embedMany(inputs: string[]): Promise { + return Promise.all(inputs.map((input) => this.embed(input))); + } +} + +function embeddingFor(input: string): number[] { + const normalized = input.toLowerCase(); + + if (normalized.includes('bun') && normalized.includes('typescript')) { + return [1, 0, 0]; + } + + if (normalized.includes('tooling') || normalized.includes('runtime')) { + return [0.98, 0.02, 0]; + } + + if (normalized.includes('typescript')) { + return [0.9, 0.1, 0]; + } + + if (normalized.includes('python')) { + return [0, 1, 0]; + } + + if (normalized.includes('database')) { + return [0, 0.2, 0.8]; + } + + return [0.1, 0.1, 0.1]; +} + +describe('IdentityDB semantic search', () => { + let db: IdentityDB; + let provider: FakeEmbeddingProvider; + + beforeEach(async () => { + provider = new FakeEmbeddingProvider(); + db = await IdentityDB.connect({ client: 'sqlite', filename: ':memory:' }); + await db.initialize(); + + await db.addFact({ + statement: 'Bun runs TypeScript tooling quickly.', + topics: [ + { name: 'Bun', category: 'entity', granularity: 'concrete' }, + { name: 'TypeScript', category: 'entity', granularity: 'concrete' }, + ], + }); + + await db.addFact({ + statement: 'TypeScript compiles to JavaScript.', + topics: [ + { name: 'TypeScript', category: 'entity', granularity: 'concrete' }, + { name: 'JavaScript', category: 'entity', granularity: 'concrete' }, + ], + }); + + await db.addFact({ + statement: 'Python uses indentation syntax.', + topics: [ + { name: 'Python', category: 'entity', granularity: 'concrete' }, + ], + }); + }); + + afterEach(async () => { + await db.close(); + }); + + it('indexes facts and returns semantic search matches ordered by score', async () => { + await db.indexFactEmbeddings({ provider }); + + const matches = await db.searchFacts({ + query: 'TypeScript runtime tooling', + provider, + limit: 2, + }); + + expect(matches).toHaveLength(2); + expect(matches[0]?.statement).toBe('Bun runs TypeScript tooling quickly.'); + expect(matches[1]?.statement).toBe('TypeScript compiles to JavaScript.'); + expect(matches[0]!.score).toBeGreaterThan(matches[1]!.score); + }); + + it('filters semantic search candidates by topic names', async () => { + await db.indexFactEmbeddings({ provider }); + + const matches = await db.searchFacts({ + query: 'TypeScript runtime tooling', + provider, + topicNames: ['Python'], + limit: 5, + }); + + expect(matches.map((match) => match.statement)).toEqual(['Python uses indentation syntax.']); + }); + + it('finds similar facts from an input statement', async () => { + await db.indexFactEmbeddings({ provider }); + + const matches = await db.findSimilarFacts({ + statement: 'Bun makes TypeScript tooling fast.', + provider, + limit: 2, + }); + + expect(matches[0]?.statement).toBe('Bun runs TypeScript tooling quickly.'); + expect(matches[0]!.score).toBeGreaterThan(matches[1]!.score); + }); +}); + +describe('IdentityDB dedup-aware ingestion', () => { + let db: IdentityDB; + let provider: FakeEmbeddingProvider; + let extractor: FactExtractor; + + beforeEach(async () => { + provider = new FakeEmbeddingProvider(); + extractor = { + async extract(input) { + return { + statement: input, + topics: [ + { name: 'Bun', category: 'entity', granularity: 'concrete' }, + { name: 'TypeScript', category: 'entity', granularity: 'concrete' }, + ], + }; + }, + }; + + db = await IdentityDB.connect({ client: 'sqlite', filename: ':memory:' }); + await db.initialize(); + }); + + afterEach(async () => { + await db.close(); + }); + + it('returns the existing fact when ingestion detects a semantic duplicate', async () => { + const first = await db.ingestStatement('Bun runs TypeScript tooling quickly.', { + extractor, + embeddingProvider: provider, + }); + + const second = await db.ingestStatement('Bun makes TypeScript tooling fast.', { + extractor, + embeddingProvider: provider, + duplicateThreshold: 0.95, + }); + + const facts = await db.getTopicFacts('TypeScript'); + + expect(second.id).toBe(first.id); + expect(facts).toHaveLength(1); + expect(facts[0]?.statement).toBe('Bun runs TypeScript tooling quickly.'); + }); +});