feat: add semantic fact search and embeddings

This commit is contained in:
2026-05-11 12:05:47 +09:00
parent 428f5021e8
commit 810f4a6bf2
10 changed files with 529 additions and 4 deletions

View File

@@ -2,7 +2,11 @@ import {
type ConnectedTopic,
type Fact,
type FactTopic,
type FindSimilarFactsInput,
type IndexFactEmbeddingsInput,
type ListTopicsOptions,
type ScoredFact,
type SearchFactsInput,
type Topic,
type TopicLookupOptions,
type TopicWithFacts,
@@ -19,11 +23,15 @@ import { IdentityDBError } from './errors';
import { initializeSchema } from './migrations';
import {
canonicalizeTopicName,
cosineSimilarity,
createContentHash,
createId,
deserializeEmbedding,
mapFactRow,
mapTopicRow,
normalizeTopicName,
nowIsoString,
serializeEmbedding,
serializeMetadata,
} from './utils';
import { extractFact } from '../ingestion/extractor';
@@ -155,7 +163,107 @@ export class IdentityDB {
factInput.metadata = extracted.metadata;
}
return this.addFact(factInput);
if (options.embeddingProvider) {
const similarFacts = await this.findSimilarFacts({
statement: factInput.statement,
provider: options.embeddingProvider,
topicNames: factInput.topics.map((topic) => topic.name),
limit: 1,
minimumScore: options.duplicateThreshold ?? 0.97,
});
if (similarFacts[0]) {
return similarFacts[0];
}
}
const fact = await this.addFact(factInput);
if (options.embeddingProvider) {
await this.indexFactEmbedding(fact.id, { provider: options.embeddingProvider });
}
return fact;
}
async indexFactEmbeddings(input: IndexFactEmbeddingsInput): Promise<void> {
const factRows = await this.connection.db.selectFrom('facts').selectAll().orderBy('created_at', 'asc').execute();
if (factRows.length === 0) {
return;
}
const embeddings = input.provider.embedMany
? await input.provider.embedMany(factRows.map((factRow) => factRow.statement))
: await Promise.all(factRows.map((factRow) => input.provider.embed(factRow.statement)));
if (embeddings.length !== factRows.length) {
throw new IdentityDBError('Embedding provider returned a mismatched number of embeddings.');
}
await this.connection.db.transaction().execute(async (trx) => {
for (let index = 0; index < factRows.length; index += 1) {
const factRow = factRows[index]!;
const embedding = embeddings[index]!;
this.assertEmbeddingShape(embedding, input.provider.dimensions);
await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model);
}
});
}
async indexFactEmbedding(factId: string, input: IndexFactEmbeddingsInput): Promise<void> {
const factRow = await this.connection.db
.selectFrom('facts')
.selectAll()
.where('id', '=', factId)
.executeTakeFirst();
if (!factRow) {
throw new IdentityDBError(`Fact not found: ${factId}`);
}
const embedding = await input.provider.embed(factRow.statement);
this.assertEmbeddingShape(embedding, input.provider.dimensions);
await this.connection.db.transaction().execute(async (trx) => {
await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model);
});
}
async searchFacts(input: SearchFactsInput): Promise<ScoredFact[]> {
const queryText = input.query.trim();
if (queryText.length === 0) {
return [];
}
const queryEmbedding = await input.provider.embed(queryText);
this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions);
return this.searchFactsByEmbedding({
providerModel: input.provider.model,
queryEmbedding,
topicNames: input.topicNames,
limit: input.limit,
minimumScore: input.minimumScore,
});
}
async findSimilarFacts(input: FindSimilarFactsInput): Promise<ScoredFact[]> {
const statement = input.statement.trim();
if (statement.length === 0) {
return [];
}
const queryEmbedding = await input.provider.embed(statement);
this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions);
return this.searchFactsByEmbedding({
providerModel: input.provider.model,
queryEmbedding,
topicNames: input.topicNames,
limit: input.limit,
minimumScore: input.minimumScore,
});
}
async linkTopics(input: LinkTopicsInput): Promise<void> {
@@ -413,6 +521,122 @@ export class IdentityDB {
}));
}
private async searchFactsByEmbedding(input: {
providerModel: string;
queryEmbedding: number[];
topicNames?: string[] | undefined;
limit?: number | undefined;
minimumScore?: number | undefined;
}): Promise<ScoredFact[]> {
const topicIds = await this.resolveTopicIds(input.topicNames);
if (topicIds === null) {
return [];
}
const factRows = topicIds.length > 0
? await findFactRowsConnectingTopicIds(this.connection.db, topicIds)
: await this.connection.db
.selectFrom('facts')
.innerJoin('fact_embeddings', 'fact_embeddings.fact_id', 'facts.id')
.selectAll('facts')
.where('fact_embeddings.model', '=', input.providerModel)
.orderBy('facts.created_at', 'asc')
.execute();
if (factRows.length === 0) {
return [];
}
const embeddingRowsQuery = this.connection.db
.selectFrom('fact_embeddings')
.selectAll()
.where('model', '=', input.providerModel);
const embeddingRows = factRows.length > 0
? await embeddingRowsQuery.where('fact_id', 'in', factRows.map((factRow) => factRow.id)).execute()
: [];
const embeddingsByFactId = new Map(
embeddingRows.map((embeddingRow) => [embeddingRow.fact_id, deserializeEmbedding(embeddingRow.embedding)]),
);
const scoredRows = factRows
.map((factRow) => ({
factRow,
score: cosineSimilarity(input.queryEmbedding, embeddingsByFactId.get(factRow.id) ?? []),
}))
.filter((entry) => entry.score >= (input.minimumScore ?? 0))
.sort((left, right) => {
if (right.score !== left.score) {
return right.score - left.score;
}
return left.factRow.created_at.localeCompare(right.factRow.created_at);
})
.slice(0, input.limit ?? 5);
if (scoredRows.length === 0) {
return [];
}
const hydratedFacts = await this.hydrateFacts(scoredRows.map((entry) => entry.factRow));
const factsById = new Map(hydratedFacts.map((fact) => [fact.id, fact]));
return scoredRows.map((entry) => ({
...factsById.get(entry.factRow.id)!,
score: entry.score,
}));
}
private async resolveTopicIds(topicNames?: string[]): Promise<string[] | null> {
if (!topicNames || topicNames.length === 0) {
return [];
}
const topicRows = await Promise.all(topicNames.map((topicName) => this.getRequiredTopicRow(topicName)));
if (topicRows.some((topicRow) => !topicRow)) {
return null;
}
return topicRows.map((topicRow) => topicRow!.id);
}
private async upsertFactEmbeddingRecord(
executor: DatabaseExecutor,
factId: string,
statement: string,
embedding: number[],
model: string,
): Promise<void> {
const timestamp = nowIsoString();
await executor
.deleteFrom('fact_embeddings')
.where('fact_id', '=', factId)
.where('model', '=', model)
.execute();
await executor
.insertInto('fact_embeddings')
.values({
fact_id: factId,
model,
dimensions: embedding.length,
embedding: serializeEmbedding(embedding),
content_hash: createContentHash(statement),
created_at: timestamp,
updated_at: timestamp,
})
.execute();
}
private assertEmbeddingShape(embedding: number[], expectedDimensions: number): void {
if (embedding.length !== expectedDimensions) {
throw new IdentityDBError(
`Embedding dimension mismatch. Expected ${expectedDimensions}, received ${embedding.length}.`,
);
}
}
private async upsertTopicInExecutor(
executor: DatabaseExecutor,
input: UpsertTopicInput,