feat: add semantic fact search and embeddings

This commit is contained in:
2026-05-11 12:05:47 +09:00
parent 428f5021e8
commit 810f4a6bf2
10 changed files with 529 additions and 4 deletions

View File

@@ -2,7 +2,11 @@ import {
type ConnectedTopic,
type Fact,
type FactTopic,
type FindSimilarFactsInput,
type IndexFactEmbeddingsInput,
type ListTopicsOptions,
type ScoredFact,
type SearchFactsInput,
type Topic,
type TopicLookupOptions,
type TopicWithFacts,
@@ -19,11 +23,15 @@ import { IdentityDBError } from './errors';
import { initializeSchema } from './migrations';
import {
canonicalizeTopicName,
cosineSimilarity,
createContentHash,
createId,
deserializeEmbedding,
mapFactRow,
mapTopicRow,
normalizeTopicName,
nowIsoString,
serializeEmbedding,
serializeMetadata,
} from './utils';
import { extractFact } from '../ingestion/extractor';
@@ -155,7 +163,107 @@ export class IdentityDB {
factInput.metadata = extracted.metadata;
}
return this.addFact(factInput);
if (options.embeddingProvider) {
const similarFacts = await this.findSimilarFacts({
statement: factInput.statement,
provider: options.embeddingProvider,
topicNames: factInput.topics.map((topic) => topic.name),
limit: 1,
minimumScore: options.duplicateThreshold ?? 0.97,
});
if (similarFacts[0]) {
return similarFacts[0];
}
}
const fact = await this.addFact(factInput);
if (options.embeddingProvider) {
await this.indexFactEmbedding(fact.id, { provider: options.embeddingProvider });
}
return fact;
}
async indexFactEmbeddings(input: IndexFactEmbeddingsInput): Promise<void> {
const factRows = await this.connection.db.selectFrom('facts').selectAll().orderBy('created_at', 'asc').execute();
if (factRows.length === 0) {
return;
}
const embeddings = input.provider.embedMany
? await input.provider.embedMany(factRows.map((factRow) => factRow.statement))
: await Promise.all(factRows.map((factRow) => input.provider.embed(factRow.statement)));
if (embeddings.length !== factRows.length) {
throw new IdentityDBError('Embedding provider returned a mismatched number of embeddings.');
}
await this.connection.db.transaction().execute(async (trx) => {
for (let index = 0; index < factRows.length; index += 1) {
const factRow = factRows[index]!;
const embedding = embeddings[index]!;
this.assertEmbeddingShape(embedding, input.provider.dimensions);
await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model);
}
});
}
async indexFactEmbedding(factId: string, input: IndexFactEmbeddingsInput): Promise<void> {
const factRow = await this.connection.db
.selectFrom('facts')
.selectAll()
.where('id', '=', factId)
.executeTakeFirst();
if (!factRow) {
throw new IdentityDBError(`Fact not found: ${factId}`);
}
const embedding = await input.provider.embed(factRow.statement);
this.assertEmbeddingShape(embedding, input.provider.dimensions);
await this.connection.db.transaction().execute(async (trx) => {
await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model);
});
}
async searchFacts(input: SearchFactsInput): Promise<ScoredFact[]> {
const queryText = input.query.trim();
if (queryText.length === 0) {
return [];
}
const queryEmbedding = await input.provider.embed(queryText);
this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions);
return this.searchFactsByEmbedding({
providerModel: input.provider.model,
queryEmbedding,
topicNames: input.topicNames,
limit: input.limit,
minimumScore: input.minimumScore,
});
}
async findSimilarFacts(input: FindSimilarFactsInput): Promise<ScoredFact[]> {
const statement = input.statement.trim();
if (statement.length === 0) {
return [];
}
const queryEmbedding = await input.provider.embed(statement);
this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions);
return this.searchFactsByEmbedding({
providerModel: input.provider.model,
queryEmbedding,
topicNames: input.topicNames,
limit: input.limit,
minimumScore: input.minimumScore,
});
}
async linkTopics(input: LinkTopicsInput): Promise<void> {
@@ -413,6 +521,122 @@ export class IdentityDB {
}));
}
private async searchFactsByEmbedding(input: {
providerModel: string;
queryEmbedding: number[];
topicNames?: string[] | undefined;
limit?: number | undefined;
minimumScore?: number | undefined;
}): Promise<ScoredFact[]> {
const topicIds = await this.resolveTopicIds(input.topicNames);
if (topicIds === null) {
return [];
}
const factRows = topicIds.length > 0
? await findFactRowsConnectingTopicIds(this.connection.db, topicIds)
: await this.connection.db
.selectFrom('facts')
.innerJoin('fact_embeddings', 'fact_embeddings.fact_id', 'facts.id')
.selectAll('facts')
.where('fact_embeddings.model', '=', input.providerModel)
.orderBy('facts.created_at', 'asc')
.execute();
if (factRows.length === 0) {
return [];
}
const embeddingRowsQuery = this.connection.db
.selectFrom('fact_embeddings')
.selectAll()
.where('model', '=', input.providerModel);
const embeddingRows = factRows.length > 0
? await embeddingRowsQuery.where('fact_id', 'in', factRows.map((factRow) => factRow.id)).execute()
: [];
const embeddingsByFactId = new Map(
embeddingRows.map((embeddingRow) => [embeddingRow.fact_id, deserializeEmbedding(embeddingRow.embedding)]),
);
const scoredRows = factRows
.map((factRow) => ({
factRow,
score: cosineSimilarity(input.queryEmbedding, embeddingsByFactId.get(factRow.id) ?? []),
}))
.filter((entry) => entry.score >= (input.minimumScore ?? 0))
.sort((left, right) => {
if (right.score !== left.score) {
return right.score - left.score;
}
return left.factRow.created_at.localeCompare(right.factRow.created_at);
})
.slice(0, input.limit ?? 5);
if (scoredRows.length === 0) {
return [];
}
const hydratedFacts = await this.hydrateFacts(scoredRows.map((entry) => entry.factRow));
const factsById = new Map(hydratedFacts.map((fact) => [fact.id, fact]));
return scoredRows.map((entry) => ({
...factsById.get(entry.factRow.id)!,
score: entry.score,
}));
}
private async resolveTopicIds(topicNames?: string[]): Promise<string[] | null> {
if (!topicNames || topicNames.length === 0) {
return [];
}
const topicRows = await Promise.all(topicNames.map((topicName) => this.getRequiredTopicRow(topicName)));
if (topicRows.some((topicRow) => !topicRow)) {
return null;
}
return topicRows.map((topicRow) => topicRow!.id);
}
private async upsertFactEmbeddingRecord(
executor: DatabaseExecutor,
factId: string,
statement: string,
embedding: number[],
model: string,
): Promise<void> {
const timestamp = nowIsoString();
await executor
.deleteFrom('fact_embeddings')
.where('fact_id', '=', factId)
.where('model', '=', model)
.execute();
await executor
.insertInto('fact_embeddings')
.values({
fact_id: factId,
model,
dimensions: embedding.length,
embedding: serializeEmbedding(embedding),
content_hash: createContentHash(statement),
created_at: timestamp,
updated_at: timestamp,
})
.execute();
}
private assertEmbeddingShape(embedding: number[], expectedDimensions: number): void {
if (embedding.length !== expectedDimensions) {
throw new IdentityDBError(
`Embedding dimension mismatch. Expected ${expectedDimensions}, received ${embedding.length}.`,
);
}
}
private async upsertTopicInExecutor(
executor: DatabaseExecutor,
input: UpsertTopicInput,

View File

@@ -2,6 +2,7 @@ import type { Kysely } from 'kysely';
import {
FACTS_TABLE,
FACT_EMBEDDINGS_TABLE,
FACT_TOPICS_TABLE,
TOPIC_ALIASES_TABLE,
TOPIC_RELATIONS_TABLE,
@@ -39,6 +40,21 @@ export async function initializeSchema(
.addColumn('updated_at', 'text', (column) => column.notNull())
.execute();
await db.schema
.createTable(FACT_EMBEDDINGS_TABLE)
.ifNotExists()
.addColumn('fact_id', 'text', (column) =>
column.notNull().references(`${FACTS_TABLE}.id`).onDelete('cascade'),
)
.addColumn('model', 'text', (column) => column.notNull())
.addColumn('dimensions', 'integer', (column) => column.notNull())
.addColumn('embedding', 'text', (column) => column.notNull())
.addColumn('content_hash', 'text', (column) => column.notNull())
.addColumn('created_at', 'text', (column) => column.notNull())
.addColumn('updated_at', 'text', (column) => column.notNull())
.addPrimaryKeyConstraint('fact_embeddings_pk', ['fact_id', 'model'])
.execute();
await db.schema
.createTable(FACT_TOPICS_TABLE)
.ifNotExists()
@@ -96,6 +112,13 @@ export async function initializeSchema(
.column('fact_id')
.execute();
await db.schema
.createIndex('fact_embeddings_model_idx')
.ifNotExists()
.on(FACT_EMBEDDINGS_TABLE)
.column('model')
.execute();
await db.schema
.createIndex('topic_relations_parent_topic_id_idx')
.ifNotExists()

View File

@@ -3,6 +3,7 @@ export const FACTS_TABLE = 'facts';
export const FACT_TOPICS_TABLE = 'fact_topics';
export const TOPIC_RELATIONS_TABLE = 'topic_relations';
export const TOPIC_ALIASES_TABLE = 'topic_aliases';
export const FACT_EMBEDDINGS_TABLE = 'fact_embeddings';
export const TOPIC_COLUMNS = [
'id',
@@ -51,3 +52,13 @@ export const TOPIC_ALIAS_COLUMNS = [
'created_at',
'updated_at',
] as const;
export const FACT_EMBEDDING_COLUMNS = [
'fact_id',
'model',
'dimensions',
'embedding',
'content_hash',
'created_at',
'updated_at',
] as const;

View File

@@ -1,4 +1,4 @@
import { randomUUID } from 'node:crypto';
import { createHash, randomUUID } from 'node:crypto';
import type { Fact, FactTopic, Topic } from '../types/api';
import type { FactRecord, TopicRecord } from '../types/domain';
@@ -35,6 +35,42 @@ export function deserializeMetadata(metadata: string | null): unknown | null {
return JSON.parse(metadata);
}
export function serializeEmbedding(embedding: number[]): string {
return JSON.stringify(embedding);
}
export function deserializeEmbedding(embedding: string): number[] {
return JSON.parse(embedding) as number[];
}
export function createContentHash(input: string): string {
return createHash('sha256').update(input).digest('hex');
}
export function cosineSimilarity(left: number[], right: number[]): number {
if (left.length === 0 || left.length !== right.length) {
return 0;
}
let dot = 0;
let leftMagnitude = 0;
let rightMagnitude = 0;
for (let index = 0; index < left.length; index += 1) {
const leftValue = left[index] ?? 0;
const rightValue = right[index] ?? 0;
dot += leftValue * rightValue;
leftMagnitude += leftValue * leftValue;
rightMagnitude += rightValue * rightValue;
}
if (leftMagnitude === 0 || rightMagnitude === 0) {
return 0;
}
return dot / (Math.sqrt(leftMagnitude) * Math.sqrt(rightMagnitude));
}
export function mapTopicRow(record: TopicRecord): Topic {
return {
id: record.id,

View File

@@ -1,4 +1,8 @@
import type { AddFactInput, TopicLinkInput } from '../types/api';
import type {
AddFactInput,
EmbeddingProvider,
TopicLinkInput,
} from '../types/api';
export interface ExtractedFact {
statement?: string;
@@ -15,4 +19,6 @@ export interface FactExtractor {
export interface IngestStatementOptions {
extractor: FactExtractor;
embeddingProvider?: EmbeddingProvider;
duplicateThreshold?: number;
}

View File

@@ -71,3 +71,34 @@ export interface ListTopicsOptions {
includeFacts?: boolean;
limit?: number;
}
export interface EmbeddingProvider {
model: string;
dimensions: number;
embed(input: string): Promise<number[]>;
embedMany?(inputs: string[]): Promise<number[][]>;
}
export interface IndexFactEmbeddingsInput {
provider: EmbeddingProvider;
}
export interface SearchFactsInput {
query: string;
provider: EmbeddingProvider;
topicNames?: string[];
limit?: number;
minimumScore?: number;
}
export interface FindSimilarFactsInput {
statement: string;
provider: EmbeddingProvider;
topicNames?: string[];
limit?: number;
minimumScore?: number;
}
export interface ScoredFact extends Fact {
score: number;
}

View File

@@ -1,4 +1,5 @@
import type {
FactEmbeddingRecord,
FactRecord,
FactTopicRecord,
TopicAliasRecord,
@@ -12,4 +13,5 @@ export interface IdentityDatabaseSchema {
fact_topics: FactTopicRecord;
topic_relations: TopicRelationRecord;
topic_aliases: TopicAliasRecord;
fact_embeddings: FactEmbeddingRecord;
}

View File

@@ -52,3 +52,13 @@ export interface TopicAliasRecord {
created_at: string;
updated_at: string;
}
export interface FactEmbeddingRecord {
fact_id: string;
model: string;
dimensions: number;
embedding: string;
content_hash: string;
created_at: string;
updated_at: string;
}