feat: add semantic fact search and embeddings

This commit is contained in:
2026-05-11 12:05:47 +09:00
parent 428f5021e8
commit 810f4a6bf2
10 changed files with 529 additions and 4 deletions

View File

@@ -2,7 +2,11 @@ import {
type ConnectedTopic, type ConnectedTopic,
type Fact, type Fact,
type FactTopic, type FactTopic,
type FindSimilarFactsInput,
type IndexFactEmbeddingsInput,
type ListTopicsOptions, type ListTopicsOptions,
type ScoredFact,
type SearchFactsInput,
type Topic, type Topic,
type TopicLookupOptions, type TopicLookupOptions,
type TopicWithFacts, type TopicWithFacts,
@@ -19,11 +23,15 @@ import { IdentityDBError } from './errors';
import { initializeSchema } from './migrations'; import { initializeSchema } from './migrations';
import { import {
canonicalizeTopicName, canonicalizeTopicName,
cosineSimilarity,
createContentHash,
createId, createId,
deserializeEmbedding,
mapFactRow, mapFactRow,
mapTopicRow, mapTopicRow,
normalizeTopicName, normalizeTopicName,
nowIsoString, nowIsoString,
serializeEmbedding,
serializeMetadata, serializeMetadata,
} from './utils'; } from './utils';
import { extractFact } from '../ingestion/extractor'; import { extractFact } from '../ingestion/extractor';
@@ -155,7 +163,107 @@ export class IdentityDB {
factInput.metadata = extracted.metadata; factInput.metadata = extracted.metadata;
} }
return this.addFact(factInput); if (options.embeddingProvider) {
const similarFacts = await this.findSimilarFacts({
statement: factInput.statement,
provider: options.embeddingProvider,
topicNames: factInput.topics.map((topic) => topic.name),
limit: 1,
minimumScore: options.duplicateThreshold ?? 0.97,
});
if (similarFacts[0]) {
return similarFacts[0];
}
}
const fact = await this.addFact(factInput);
if (options.embeddingProvider) {
await this.indexFactEmbedding(fact.id, { provider: options.embeddingProvider });
}
return fact;
}
/**
 * Computes an embedding for every stored fact and writes it to `fact_embeddings`.
 *
 * Facts are read oldest-first. The provider's batch `embedMany` is used when it
 * exists; otherwise `embed` is called once per statement. All rows are written
 * inside a single transaction.
 *
 * @param input Carries the embedding provider to index with.
 * @throws IdentityDBError When the provider returns a different number of
 *   embeddings than statements, or a vector fails the dimension check.
 */
async indexFactEmbeddings(input: IndexFactEmbeddingsInput): Promise<void> {
  const { provider } = input;

  const storedFacts = await this.connection.db
    .selectFrom('facts')
    .selectAll()
    .orderBy('created_at', 'asc')
    .execute();

  // Nothing to index.
  if (storedFacts.length === 0) {
    return;
  }

  const statements = storedFacts.map((row) => row.statement);

  let vectors: number[][];
  if (provider.embedMany) {
    vectors = await provider.embedMany(statements);
  } else {
    vectors = await Promise.all(statements.map((statement) => provider.embed(statement)));
  }

  // Vectors are paired with facts by position, so the counts must agree.
  if (vectors.length !== storedFacts.length) {
    throw new IdentityDBError('Embedding provider returned a mismatched number of embeddings.');
  }

  await this.connection.db.transaction().execute(async (trx) => {
    for (const [position, row] of storedFacts.entries()) {
      const vector = vectors[position]!;
      this.assertEmbeddingShape(vector, provider.dimensions);
      await this.upsertFactEmbeddingRecord(trx, row.id, row.statement, vector, provider.model);
    }
  });
}
/**
 * Computes and stores the embedding of a single fact.
 *
 * @param factId Id of the fact to index.
 * @param input Carries the embedding provider to index with.
 * @throws IdentityDBError When no fact has the given id, or the returned vector
 *   fails the dimension check.
 */
async indexFactEmbedding(factId: string, input: IndexFactEmbeddingsInput): Promise<void> {
  const { provider } = input;

  const target = await this.connection.db
    .selectFrom('facts')
    .selectAll()
    .where('id', '=', factId)
    .executeTakeFirst();

  if (!target) {
    throw new IdentityDBError(`Fact not found: ${factId}`);
  }

  const vector = await provider.embed(target.statement);
  this.assertEmbeddingShape(vector, provider.dimensions);

  // The upsert issues more than one statement, so it runs in a transaction.
  await this.connection.db
    .transaction()
    .execute((trx) => this.upsertFactEmbeddingRecord(trx, target.id, target.statement, vector, provider.model));
}
/**
 * Embeds a free-text query and returns the stored facts most similar to it.
 *
 * @param input Query text, embedding provider and optional topic filter,
 *   result limit and minimum score.
 * @returns Scored facts; an empty array when the query is blank after trimming.
 * @throws IdentityDBError When the query embedding fails the dimension check.
 */
async searchFacts(input: SearchFactsInput): Promise<ScoredFact[]> {
  const trimmedQuery = input.query.trim();
  if (trimmedQuery.length === 0) {
    return [];
  }

  const { provider } = input;
  const vector = await provider.embed(trimmedQuery);
  this.assertEmbeddingShape(vector, provider.dimensions);

  return this.searchFactsByEmbedding({
    providerModel: provider.model,
    queryEmbedding: vector,
    topicNames: input.topicNames,
    limit: input.limit,
    minimumScore: input.minimumScore,
  });
}
/**
 * Embeds a candidate statement and returns the stored facts most similar to it.
 *
 * @param input Statement text, embedding provider and optional topic filter,
 *   result limit and minimum score.
 * @returns Scored facts; an empty array when the statement is blank after trimming.
 * @throws IdentityDBError When the statement embedding fails the dimension check.
 */
async findSimilarFacts(input: FindSimilarFactsInput): Promise<ScoredFact[]> {
  const candidate = input.statement.trim();
  if (candidate.length === 0) {
    return [];
  }

  const { provider } = input;
  const vector = await provider.embed(candidate);
  this.assertEmbeddingShape(vector, provider.dimensions);

  return this.searchFactsByEmbedding({
    providerModel: provider.model,
    queryEmbedding: vector,
    topicNames: input.topicNames,
    limit: input.limit,
    minimumScore: input.minimumScore,
  });
}
async linkTopics(input: LinkTopicsInput): Promise<void> { async linkTopics(input: LinkTopicsInput): Promise<void> {
@@ -413,6 +521,122 @@ export class IdentityDB {
})); }));
} }
/**
 * Ranks stored facts against a query embedding by cosine similarity.
 *
 * Without a topic filter, candidates are the facts that have an embedding
 * stored for `providerModel`. With one, candidates come from
 * `findFactRowsConnectingTopicIds` (presumably the facts linked to the given
 * topics — confirm against that helper).
 *
 * Facts with no comparable embedding (none stored for the model, or one whose
 * length differs from the query's) are skipped in both paths. Previously the
 * topic-filtered path returned them with a score of 0, unlike the unfiltered
 * path.
 *
 * @param input.providerModel Model name the stored embeddings must match.
 * @param input.queryEmbedding Vector to compare against.
 * @param input.topicNames Optional topic filter; an unresolved topic yields `[]`.
 * @param input.limit Maximum number of results; defaults to 5. Values below 1
 *   yield an empty result.
 * @param input.minimumScore Inclusive lower bound on the score; defaults to 0.
 * @returns Hydrated facts with their score, best first; ties are ordered by
 *   ascending `created_at`.
 */
private async searchFactsByEmbedding(input: {
  providerModel: string;
  queryEmbedding: number[];
  topicNames?: string[] | undefined;
  limit?: number | undefined;
  minimumScore?: number | undefined;
}): Promise<ScoredFact[]> {
  const topicIds = await this.resolveTopicIds(input.topicNames);
  if (topicIds === null) {
    return [];
  }

  // A negative limit would make `slice(0, limit)` drop results from the end,
  // so clamp it. The negated comparison also catches NaN.
  const maxResults = Math.max(0, Math.floor(input.limit ?? 5));
  if (!(maxResults > 0)) {
    return [];
  }

  const factRows = topicIds.length > 0
    ? await findFactRowsConnectingTopicIds(this.connection.db, topicIds)
    : await this.connection.db
        .selectFrom('facts')
        .innerJoin('fact_embeddings', 'fact_embeddings.fact_id', 'facts.id')
        .selectAll('facts')
        .where('fact_embeddings.model', '=', input.providerModel)
        .orderBy('facts.created_at', 'asc')
        .execute();

  if (factRows.length === 0) {
    return [];
  }

  const embeddingRows = await this.connection.db
    .selectFrom('fact_embeddings')
    .selectAll()
    .where('model', '=', input.providerModel)
    .where('fact_id', 'in', factRows.map((factRow) => factRow.id))
    .execute();

  const embeddingsByFactId = new Map(
    embeddingRows.map((embeddingRow) => [embeddingRow.fact_id, deserializeEmbedding(embeddingRow.embedding)]),
  );

  const minimumScore = input.minimumScore ?? 0;
  const scoredRows: Array<{ factRow: (typeof factRows)[number]; score: number }> = [];

  for (const factRow of factRows) {
    const embedding = embeddingsByFactId.get(factRow.id);

    // Not indexed for this model, or indexed with another dimensionality:
    // there is nothing meaningful to compare, so the fact is not a match.
    if (embedding === undefined || embedding.length !== input.queryEmbedding.length) {
      continue;
    }

    const score = cosineSimilarity(input.queryEmbedding, embedding);
    if (score >= minimumScore) {
      scoredRows.push({ factRow, score });
    }
  }

  if (scoredRows.length === 0) {
    return [];
  }

  const topRows = scoredRows
    .sort((left, right) => {
      if (right.score !== left.score) {
        return right.score - left.score;
      }

      return left.factRow.created_at.localeCompare(right.factRow.created_at);
    })
    .slice(0, maxResults);

  const hydratedFacts = await this.hydrateFacts(topRows.map((entry) => entry.factRow));
  const factsById = new Map(hydratedFacts.map((fact) => [fact.id, fact]));

  // NOTE(review): assumes hydrateFacts returns an entry for every row passed in — confirm.
  return topRows.map((entry) => ({
    ...factsById.get(entry.factRow.id)!,
    score: entry.score,
  }));
}
/**
 * Translates topic names into topic ids.
 *
 * @param topicNames Names to resolve; may be omitted.
 * @returns `[]` when no names are given, `null` when any lookup yields a falsy
 *   row, otherwise the ids in the same order as the names.
 */
private async resolveTopicIds(topicNames?: string[]): Promise<string[] | null> {
  if (!topicNames?.length) {
    return [];
  }

  // NOTE(review): if getRequiredTopicRow throws for unknown topics (its name
  // suggests it might), the `null` branch below is unreachable — confirm.
  const lookups = topicNames.map((topicName) => this.getRequiredTopicRow(topicName));
  const resolvedRows = await Promise.all(lookups);

  const ids: string[] = [];
  for (const resolvedRow of resolvedRows) {
    if (!resolvedRow) {
      return null;
    }
    ids.push(resolvedRow.id);
  }

  return ids;
}
/**
 * Inserts or replaces the embedding stored for a (fact, model) pair.
 *
 * The row is replaced with delete-then-insert. The original `created_at` of an
 * existing row is carried over so that re-indexing only moves `updated_at`;
 * previously both timestamps were reset on every call.
 *
 * @param executor Database handle or open transaction to run the statements on.
 * @param factId Id of the fact the embedding belongs to.
 * @param statement Fact text the embedding was computed from; stored as a content hash.
 * @param embedding Vector to store.
 * @param model Name of the model that produced the vector.
 */
private async upsertFactEmbeddingRecord(
  executor: DatabaseExecutor,
  factId: string,
  statement: string,
  embedding: number[],
  model: string,
): Promise<void> {
  const timestamp = nowIsoString();

  // Read the creation time before the row is removed.
  const existing = await executor
    .selectFrom('fact_embeddings')
    .select('created_at')
    .where('fact_id', '=', factId)
    .where('model', '=', model)
    .executeTakeFirst();

  await executor
    .deleteFrom('fact_embeddings')
    .where('fact_id', '=', factId)
    .where('model', '=', model)
    .execute();

  await executor
    .insertInto('fact_embeddings')
    .values({
      fact_id: factId,
      model,
      dimensions: embedding.length,
      embedding: serializeEmbedding(embedding),
      content_hash: createContentHash(statement),
      created_at: existing?.created_at ?? timestamp,
      updated_at: timestamp,
    })
    .execute();
}
/**
 * Verifies that an embedding has the number of components the provider declares.
 *
 * @param embedding Vector to check.
 * @param expectedDimensions Required vector length.
 * @throws IdentityDBError When the lengths differ.
 */
private assertEmbeddingShape(embedding: number[], expectedDimensions: number): void {
  if (embedding.length === expectedDimensions) {
    return;
  }

  throw new IdentityDBError(
    `Embedding dimension mismatch. Expected ${expectedDimensions}, received ${embedding.length}.`,
  );
}
private async upsertTopicInExecutor( private async upsertTopicInExecutor(
executor: DatabaseExecutor, executor: DatabaseExecutor,
input: UpsertTopicInput, input: UpsertTopicInput,

View File

@@ -2,6 +2,7 @@ import type { Kysely } from 'kysely';
import { import {
FACTS_TABLE, FACTS_TABLE,
FACT_EMBEDDINGS_TABLE,
FACT_TOPICS_TABLE, FACT_TOPICS_TABLE,
TOPIC_ALIASES_TABLE, TOPIC_ALIASES_TABLE,
TOPIC_RELATIONS_TABLE, TOPIC_RELATIONS_TABLE,
@@ -39,6 +40,21 @@ export async function initializeSchema(
.addColumn('updated_at', 'text', (column) => column.notNull()) .addColumn('updated_at', 'text', (column) => column.notNull())
.execute(); .execute();
await db.schema
.createTable(FACT_EMBEDDINGS_TABLE)
.ifNotExists()
.addColumn('fact_id', 'text', (column) =>
column.notNull().references(`${FACTS_TABLE}.id`).onDelete('cascade'),
)
.addColumn('model', 'text', (column) => column.notNull())
.addColumn('dimensions', 'integer', (column) => column.notNull())
.addColumn('embedding', 'text', (column) => column.notNull())
.addColumn('content_hash', 'text', (column) => column.notNull())
.addColumn('created_at', 'text', (column) => column.notNull())
.addColumn('updated_at', 'text', (column) => column.notNull())
.addPrimaryKeyConstraint('fact_embeddings_pk', ['fact_id', 'model'])
.execute();
await db.schema await db.schema
.createTable(FACT_TOPICS_TABLE) .createTable(FACT_TOPICS_TABLE)
.ifNotExists() .ifNotExists()
@@ -96,6 +112,13 @@ export async function initializeSchema(
.column('fact_id') .column('fact_id')
.execute(); .execute();
await db.schema
.createIndex('fact_embeddings_model_idx')
.ifNotExists()
.on(FACT_EMBEDDINGS_TABLE)
.column('model')
.execute();
await db.schema await db.schema
.createIndex('topic_relations_parent_topic_id_idx') .createIndex('topic_relations_parent_topic_id_idx')
.ifNotExists() .ifNotExists()

View File

@@ -3,6 +3,7 @@ export const FACTS_TABLE = 'facts';
export const FACT_TOPICS_TABLE = 'fact_topics'; export const FACT_TOPICS_TABLE = 'fact_topics';
export const TOPIC_RELATIONS_TABLE = 'topic_relations'; export const TOPIC_RELATIONS_TABLE = 'topic_relations';
export const TOPIC_ALIASES_TABLE = 'topic_aliases'; export const TOPIC_ALIASES_TABLE = 'topic_aliases';
/** Name of the table that stores one embedding per (fact, model) pair. */
export const FACT_EMBEDDINGS_TABLE = 'fact_embeddings';
export const TOPIC_COLUMNS = [ export const TOPIC_COLUMNS = [
'id', 'id',
@@ -51,3 +52,13 @@ export const TOPIC_ALIAS_COLUMNS = [
'created_at', 'created_at',
'updated_at', 'updated_at',
] as const; ] as const;
/**
 * Column names of the `fact_embeddings` table, listed in the order the schema
 * migration declares them.
 */
export const FACT_EMBEDDING_COLUMNS = [
'fact_id',
'model',
'dimensions',
'embedding',
'content_hash',
'created_at',
'updated_at',
] as const;

View File

@@ -1,4 +1,4 @@
import { randomUUID } from 'node:crypto'; import { createHash, randomUUID } from 'node:crypto';
import type { Fact, FactTopic, Topic } from '../types/api'; import type { Fact, FactTopic, Topic } from '../types/api';
import type { FactRecord, TopicRecord } from '../types/domain'; import type { FactRecord, TopicRecord } from '../types/domain';
@@ -35,6 +35,42 @@ export function deserializeMetadata(metadata: string | null): unknown | null {
return JSON.parse(metadata); return JSON.parse(metadata);
} }
/**
 * Encodes an embedding vector as a JSON array string for storage in a text column.
 *
 * Note that `JSON.stringify` writes non-finite numbers (`NaN`, `Infinity`) as `null`.
 *
 * @param embedding Vector to encode.
 * @returns The JSON text of the vector.
 */
export function serializeEmbedding(embedding: number[]): string {
  const encoded = JSON.stringify(embedding);
  return encoded;
}
/**
 * Decodes an embedding vector from its stored JSON text.
 *
 * The parsed value is validated instead of being cast blindly, so a corrupted
 * row is reported here rather than turning into meaningless similarity scores.
 *
 * @param embedding JSON text previously produced for an embedding vector.
 * @returns The decoded vector.
 * @throws SyntaxError When the text is not valid JSON.
 * @throws TypeError When the JSON value is not an array of finite numbers.
 */
export function deserializeEmbedding(embedding: string): number[] {
  const parsed: unknown = JSON.parse(embedding);

  if (!Array.isArray(parsed)) {
    throw new TypeError('Stored embedding is not a JSON array.');
  }

  const values: unknown[] = parsed;
  const vector: number[] = [];

  for (const [position, value] of values.entries()) {
    // `null` entries (how JSON encodes NaN/Infinity) and strings end up here.
    if (typeof value !== 'number' || !Number.isFinite(value)) {
      throw new TypeError(`Stored embedding has a non-numeric component at index ${position}.`);
    }
    vector.push(value);
  }

  return vector;
}
/**
 * Produces a stable fingerprint of a piece of text.
 *
 * @param input Text to fingerprint.
 * @returns The SHA-256 digest of the text as a lowercase hex string.
 */
export function createContentHash(input: string): string {
  const hasher = createHash('sha256');
  hasher.update(input);
  return hasher.digest('hex');
}
/**
 * Computes the cosine similarity of two vectors.
 *
 * @param left First vector.
 * @param right Second vector.
 * @returns The cosine of the angle between the vectors, or 0 when the vectors
 *   are empty, differ in length, or either has zero magnitude.
 */
export function cosineSimilarity(left: number[], right: number[]): number {
  const size = left.length;
  if (size === 0 || right.length !== size) {
    return 0;
  }

  let dotProduct = 0;
  let leftSquares = 0;
  let rightSquares = 0;

  for (const [position, leftEntry] of left.entries()) {
    // Missing slots count as 0.
    const a = leftEntry ?? 0;
    const b = right[position] ?? 0;

    dotProduct += a * b;
    leftSquares += a * a;
    rightSquares += b * b;
  }

  // A zero vector has no direction; avoid dividing by zero.
  if (leftSquares === 0 || rightSquares === 0) {
    return 0;
  }

  return dotProduct / (Math.sqrt(leftSquares) * Math.sqrt(rightSquares));
}
export function mapTopicRow(record: TopicRecord): Topic { export function mapTopicRow(record: TopicRecord): Topic {
return { return {
id: record.id, id: record.id,

View File

@@ -1,4 +1,8 @@
import type { AddFactInput, TopicLinkInput } from '../types/api'; import type {
AddFactInput,
EmbeddingProvider,
TopicLinkInput,
} from '../types/api';
export interface ExtractedFact { export interface ExtractedFact {
statement?: string; statement?: string;
@@ -15,4 +19,6 @@ export interface FactExtractor {
export interface IngestStatementOptions { export interface IngestStatementOptions {
extractor: FactExtractor; extractor: FactExtractor;
embeddingProvider?: EmbeddingProvider;
duplicateThreshold?: number;
} }

View File

@@ -71,3 +71,34 @@ export interface ListTopicsOptions {
includeFacts?: boolean; includeFacts?: boolean;
limit?: number; limit?: number;
} }
/**
 * Source of embedding vectors for fact statements and search queries.
 */
export interface EmbeddingProvider {
/** Model name; stored with each embedding and used to select comparable embeddings at search time. */
model: string;
/** Length every returned vector must have; vectors of any other length are rejected. */
dimensions: number;
/** Embeds a single piece of text. */
embed(input: string): Promise<number[]>;
/**
 * Optional batch variant, preferred when indexing all facts. Must return
 * exactly one vector per input; vectors are matched to inputs by position.
 */
embedMany?(inputs: string[]): Promise<number[][]>;
}
/** Input for indexing fact embeddings. */
export interface IndexFactEmbeddingsInput {
/** Provider used to embed the fact statements. */
provider: EmbeddingProvider;
}
/** Input for a semantic search over stored facts. */
export interface SearchFactsInput {
/** Free-text query; trimmed before use, and a blank query returns no results. */
query: string;
/** Provider used to embed the query. */
provider: EmbeddingProvider;
/** Optional topic names used to narrow the candidate facts. */
topicNames?: string[];
/** Maximum number of results; defaults to 5. */
limit?: number;
/** Inclusive lower bound on the similarity score; defaults to 0. */
minimumScore?: number;
}
/** Input for finding stored facts similar to a given statement. */
export interface FindSimilarFactsInput {
/** Statement to compare against; trimmed before use, and a blank statement returns no results. */
statement: string;
/** Provider used to embed the statement. */
provider: EmbeddingProvider;
/** Optional topic names used to narrow the candidate facts. */
topicNames?: string[];
/** Maximum number of results; defaults to 5. */
limit?: number;
/** Inclusive lower bound on the similarity score; defaults to 0. */
minimumScore?: number;
}
/** A fact returned from a semantic search, together with its similarity score. */
export interface ScoredFact extends Fact {
/** Cosine similarity between the query embedding and the fact's stored embedding. */
score: number;
}

View File

@@ -1,4 +1,5 @@
import type { import type {
FactEmbeddingRecord,
FactRecord, FactRecord,
FactTopicRecord, FactTopicRecord,
TopicAliasRecord, TopicAliasRecord,
@@ -12,4 +13,5 @@ export interface IdentityDatabaseSchema {
fact_topics: FactTopicRecord; fact_topics: FactTopicRecord;
topic_relations: TopicRelationRecord; topic_relations: TopicRelationRecord;
topic_aliases: TopicAliasRecord; topic_aliases: TopicAliasRecord;
fact_embeddings: FactEmbeddingRecord;
} }

View File

@@ -52,3 +52,13 @@ export interface TopicAliasRecord {
created_at: string; created_at: string;
updated_at: string; updated_at: string;
} }
/**
 * Row shape of the `fact_embeddings` table. The schema keys rows by
 * (`fact_id`, `model`).
 */
export interface FactEmbeddingRecord {
/** Id of the fact the embedding belongs to. */
fact_id: string;
/** Name of the model that produced the embedding. */
model: string;
/** Number of components in the stored vector. */
dimensions: number;
/** The vector, serialized as JSON text. */
embedding: string;
/** SHA-256 hex digest of the fact statement the vector was computed from. */
content_hash: string;
/** Creation timestamp, written as an ISO string. */
created_at: string;
/** Last-write timestamp, written as an ISO string. */
updated_at: string;
}

View File

@@ -16,7 +16,7 @@ afterEach(async () => {
}); });
describe('initializeSchema', () => { describe('initializeSchema', () => {
it('creates the topics, facts, fact_topics, topic_relations, and topic_aliases tables', async () => { it('creates the topics, facts, fact_embeddings, fact_topics, topic_relations, and topic_aliases tables', async () => {
const connection = await createDatabase({ client: 'sqlite', filename: ':memory:' }); const connection = await createDatabase({ client: 'sqlite', filename: ':memory:' });
openConnections.push(connection.destroy); openConnections.push(connection.destroy);
@@ -33,6 +33,7 @@ describe('initializeSchema', () => {
expect(tableNames).toContain('topics'); expect(tableNames).toContain('topics');
expect(tableNames).toContain('facts'); expect(tableNames).toContain('facts');
expect(tableNames).toContain('fact_embeddings');
expect(tableNames).toContain('fact_topics'); expect(tableNames).toContain('fact_topics');
expect(tableNames).toContain('topic_relations'); expect(tableNames).toContain('topic_relations');
expect(tableNames).toContain('topic_aliases'); expect(tableNames).toContain('topic_aliases');
@@ -46,6 +47,7 @@ describe('initializeSchema', () => {
const topicsColumns = await sql<{ name: string }>`PRAGMA table_info(topics)`.execute(connection.db); const topicsColumns = await sql<{ name: string }>`PRAGMA table_info(topics)`.execute(connection.db);
const factsColumns = await sql<{ name: string }>`PRAGMA table_info(facts)`.execute(connection.db); const factsColumns = await sql<{ name: string }>`PRAGMA table_info(facts)`.execute(connection.db);
const factEmbeddingsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_embeddings)`.execute(connection.db);
const factTopicsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_topics)`.execute(connection.db); const factTopicsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_topics)`.execute(connection.db);
const topicRelationsColumns = await sql<{ name: string }>`PRAGMA table_info(topic_relations)`.execute(connection.db); const topicRelationsColumns = await sql<{ name: string }>`PRAGMA table_info(topic_relations)`.execute(connection.db);
const topicAliasesColumns = await sql<{ name: string }>`PRAGMA table_info(topic_aliases)`.execute(connection.db); const topicAliasesColumns = await sql<{ name: string }>`PRAGMA table_info(topic_aliases)`.execute(connection.db);
@@ -73,6 +75,16 @@ describe('initializeSchema', () => {
'updated_at', 'updated_at',
]); ]);
expect(factEmbeddingsColumns.rows.map((row) => row.name)).toEqual([
'fact_id',
'model',
'dimensions',
'embedding',
'content_hash',
'created_at',
'updated_at',
]);
expect(factTopicsColumns.rows.map((row) => row.name)).toEqual([ expect(factTopicsColumns.rows.map((row) => row.name)).toEqual([
'fact_id', 'fact_id',
'topic_id', 'topic_id',

View File

@@ -0,0 +1,170 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { IdentityDB } from '../src/core/identity-db';
import type { FactExtractor } from '../src/ingestion/types';
import type { EmbeddingProvider } from '../src/types/api';
/**
 * Deterministic embedding provider for tests: vectors come from the
 * keyword-based `embeddingFor` table instead of a real model.
 */
class FakeEmbeddingProvider implements EmbeddingProvider {
  model = 'fake-semantic-v1';
  dimensions = 3;

  /** Embeds one text via the keyword table. */
  async embed(input: string): Promise<number[]> {
    const vector = embeddingFor(input);
    return vector;
  }

  /** Embeds several texts, preserving input order. */
  async embedMany(inputs: string[]): Promise<number[][]> {
    const pending = inputs.map((text) => this.embed(text));
    return Promise.all(pending);
  }
}
/**
 * Maps text to a fixed 3-component vector based on the keywords it contains.
 *
 * Matching is case-insensitive, and the first matching rule wins, so rule
 * order matters.
 *
 * @param input Text to embed.
 * @returns A fresh vector for the first matching rule, or `[0.1, 0.1, 0.1]`
 *   when no rule matches.
 */
function embeddingFor(input: string): number[] {
  const text = input.toLowerCase();
  const mentions = (keyword: string): boolean => text.includes(keyword);

  const rules: Array<[matches: boolean, vector: number[]]> = [
    [mentions('bun') && mentions('typescript'), [1, 0, 0]],
    [mentions('tooling') || mentions('runtime'), [0.98, 0.02, 0]],
    [mentions('typescript'), [0.9, 0.1, 0]],
    [mentions('python'), [0, 1, 0]],
    [mentions('database'), [0, 0.2, 0.8]],
  ];

  for (const [matches, vector] of rules) {
    if (matches) {
      return vector;
    }
  }

  return [0.1, 0.1, 0.1];
}
// Semantic search specs: three facts are seeded, indexed with the fake
// provider, then queried by free text, by topic filter and by similar statement.
describe('IdentityDB semantic search', () => {
  let db: IdentityDB;
  let provider: FakeEmbeddingProvider;

  // Every seeded topic is a concrete entity; only the name varies.
  const entity = (name: string) => ({
    name,
    category: 'entity' as const,
    granularity: 'concrete' as const,
  });

  beforeEach(async () => {
    provider = new FakeEmbeddingProvider();
    db = await IdentityDB.connect({ client: 'sqlite', filename: ':memory:' });
    await db.initialize();

    await db.addFact({
      statement: 'Bun runs TypeScript tooling quickly.',
      topics: [entity('Bun'), entity('TypeScript')],
    });
    await db.addFact({
      statement: 'TypeScript compiles to JavaScript.',
      topics: [entity('TypeScript'), entity('JavaScript')],
    });
    await db.addFact({
      statement: 'Python uses indentation syntax.',
      topics: [entity('Python')],
    });
  });

  afterEach(async () => {
    await db.close();
  });

  it('indexes facts and returns semantic search matches ordered by score', async () => {
    await db.indexFactEmbeddings({ provider });

    const matches = await db.searchFacts({
      query: 'TypeScript runtime tooling',
      provider,
      limit: 2,
    });
    const [best, runnerUp] = matches;

    expect(matches).toHaveLength(2);
    expect(best?.statement).toBe('Bun runs TypeScript tooling quickly.');
    expect(runnerUp?.statement).toBe('TypeScript compiles to JavaScript.');
    expect(best!.score).toBeGreaterThan(runnerUp!.score);
  });

  it('filters semantic search candidates by topic names', async () => {
    await db.indexFactEmbeddings({ provider });

    const matches = await db.searchFacts({
      query: 'TypeScript runtime tooling',
      provider,
      topicNames: ['Python'],
      limit: 5,
    });
    const statements = matches.map(({ statement }) => statement);

    expect(statements).toEqual(['Python uses indentation syntax.']);
  });

  it('finds similar facts from an input statement', async () => {
    await db.indexFactEmbeddings({ provider });

    const matches = await db.findSimilarFacts({
      statement: 'Bun makes TypeScript tooling fast.',
      provider,
      limit: 2,
    });
    const [closest, next] = matches;

    expect(closest?.statement).toBe('Bun runs TypeScript tooling quickly.');
    expect(closest!.score).toBeGreaterThan(next!.score);
  });
});
// Ingestion specs: with an embedding provider supplied, a statement that is
// semantically close to an existing fact must resolve to that fact instead of
// creating a new one.
describe('IdentityDB dedup-aware ingestion', () => {
  let db: IdentityDB;
  let provider: FakeEmbeddingProvider;
  let extractor: FactExtractor;

  beforeEach(async () => {
    provider = new FakeEmbeddingProvider();

    // Stub extractor: echoes the input as the statement and always tags the
    // same two topics.
    extractor = {
      extract: async (input) => ({
        statement: input,
        topics: [
          { name: 'Bun', category: 'entity', granularity: 'concrete' },
          { name: 'TypeScript', category: 'entity', granularity: 'concrete' },
        ],
      }),
    };

    db = await IdentityDB.connect({ client: 'sqlite', filename: ':memory:' });
    await db.initialize();
  });

  afterEach(async () => {
    await db.close();
  });

  it('returns the existing fact when ingestion detects a semantic duplicate', async () => {
    const original = await db.ingestStatement('Bun runs TypeScript tooling quickly.', {
      extractor,
      embeddingProvider: provider,
    });

    const duplicate = await db.ingestStatement('Bun makes TypeScript tooling fast.', {
      extractor,
      embeddingProvider: provider,
      duplicateThreshold: 0.95,
    });

    const typescriptFacts = await db.getTopicFacts('TypeScript');

    expect(duplicate.id).toBe(original.id);
    expect(typescriptFacts).toHaveLength(1);
    expect(typescriptFacts[0]?.statement).toBe('Bun runs TypeScript tooling quickly.');
  });
});