feat: add semantic fact search and embeddings
This commit is contained in:
@@ -2,7 +2,11 @@ import {
|
||||
type ConnectedTopic,
|
||||
type Fact,
|
||||
type FactTopic,
|
||||
type FindSimilarFactsInput,
|
||||
type IndexFactEmbeddingsInput,
|
||||
type ListTopicsOptions,
|
||||
type ScoredFact,
|
||||
type SearchFactsInput,
|
||||
type Topic,
|
||||
type TopicLookupOptions,
|
||||
type TopicWithFacts,
|
||||
@@ -19,11 +23,15 @@ import { IdentityDBError } from './errors';
|
||||
import { initializeSchema } from './migrations';
|
||||
import {
|
||||
canonicalizeTopicName,
|
||||
cosineSimilarity,
|
||||
createContentHash,
|
||||
createId,
|
||||
deserializeEmbedding,
|
||||
mapFactRow,
|
||||
mapTopicRow,
|
||||
normalizeTopicName,
|
||||
nowIsoString,
|
||||
serializeEmbedding,
|
||||
serializeMetadata,
|
||||
} from './utils';
|
||||
import { extractFact } from '../ingestion/extractor';
|
||||
@@ -155,7 +163,107 @@ export class IdentityDB {
|
||||
factInput.metadata = extracted.metadata;
|
||||
}
|
||||
|
||||
return this.addFact(factInput);
|
||||
if (options.embeddingProvider) {
|
||||
const similarFacts = await this.findSimilarFacts({
|
||||
statement: factInput.statement,
|
||||
provider: options.embeddingProvider,
|
||||
topicNames: factInput.topics.map((topic) => topic.name),
|
||||
limit: 1,
|
||||
minimumScore: options.duplicateThreshold ?? 0.97,
|
||||
});
|
||||
|
||||
if (similarFacts[0]) {
|
||||
return similarFacts[0];
|
||||
}
|
||||
}
|
||||
|
||||
const fact = await this.addFact(factInput);
|
||||
|
||||
if (options.embeddingProvider) {
|
||||
await this.indexFactEmbedding(fact.id, { provider: options.embeddingProvider });
|
||||
}
|
||||
|
||||
return fact;
|
||||
}
|
||||
|
||||
/**
 * Computes and stores an embedding for every row in the `facts` table.
 *
 * Facts are read in ascending `created_at` order and embedded through
 * `input.provider`: one `embedMany` call when the provider exposes it,
 * otherwise one `embed` call per fact. All embedding rows are then written
 * inside a single database transaction.
 *
 * @param input - Carries the embedding provider (its `model` and `dimensions` are used).
 * @throws IdentityDBError when the provider returns a different number of
 *   embeddings than there are facts, or when an embedding's length differs
 *   from `input.provider.dimensions`.
 */
async indexFactEmbeddings(input: IndexFactEmbeddingsInput): Promise<void> {
  const { provider } = input;

  const storedFacts = await this.connection.db
    .selectFrom('facts')
    .selectAll()
    .orderBy('created_at', 'asc')
    .execute();

  // Nothing to index; avoid calling the provider at all.
  if (storedFacts.length === 0) {
    return;
  }

  const statements = storedFacts.map((storedFact) => storedFact.statement);
  const vectors = provider.embedMany
    ? await provider.embedMany(statements)
    : await Promise.all(statements.map((text) => provider.embed(text)));

  // Vectors are matched to facts by position, so the counts must agree.
  if (vectors.length !== storedFacts.length) {
    throw new IdentityDBError('Embedding provider returned a mismatched number of embeddings.');
  }

  await this.connection.db.transaction().execute(async (trx) => {
    for (const [position, storedFact] of storedFacts.entries()) {
      const vector = vectors[position]!;
      this.assertEmbeddingShape(vector, provider.dimensions);
      await this.upsertFactEmbeddingRecord(trx, storedFact.id, storedFact.statement, vector, provider.model);
    }
  });
}
|
||||
|
||||
/**
 * Computes and stores the embedding for one fact.
 *
 * @param factId - Identifier of the fact row to embed.
 * @param input - Carries the embedding provider (its `model` and `dimensions` are used).
 * @throws IdentityDBError when no fact row has the given id, or when the
 *   embedding's length differs from `input.provider.dimensions`.
 */
async indexFactEmbedding(factId: string, input: IndexFactEmbeddingsInput): Promise<void> {
  const { provider } = input;

  const factLookup = this.connection.db.selectFrom('facts').selectAll().where('id', '=', factId);
  const storedFact = await factLookup.executeTakeFirst();

  if (!storedFact) {
    throw new IdentityDBError(`Fact not found: ${factId}`);
  }

  const vector = await provider.embed(storedFact.statement);
  this.assertEmbeddingShape(vector, provider.dimensions);

  // The upsert is a delete followed by an insert, so run both in one transaction.
  await this.connection.db
    .transaction()
    .execute((trx) =>
      this.upsertFactEmbeddingRecord(trx, storedFact.id, storedFact.statement, vector, provider.model),
    );
}
|
||||
|
||||
/**
 * Searches stored facts by semantic similarity to a free-text query.
 *
 * The query is trimmed, embedded with `input.provider`, and the resulting
 * vector is handed to `searchFactsByEmbedding` together with the optional
 * topic filter, limit, and minimum score.
 *
 * @param input - Query text, embedding provider, and optional filters.
 * @returns Scored facts; an empty array when the trimmed query is empty.
 * @throws IdentityDBError when the query embedding's length differs from
 *   `input.provider.dimensions`.
 */
async searchFacts(input: SearchFactsInput): Promise<ScoredFact[]> {
  const { provider } = input;
  const trimmedQuery = input.query.trim();

  // A blank query has nothing to embed.
  if (trimmedQuery === '') {
    return [];
  }

  const queryEmbedding = await provider.embed(trimmedQuery);
  this.assertEmbeddingShape(queryEmbedding, provider.dimensions);

  const { topicNames, limit, minimumScore } = input;
  return this.searchFactsByEmbedding({
    providerModel: provider.model,
    queryEmbedding,
    topicNames,
    limit,
    minimumScore,
  });
}
|
||||
|
||||
/**
 * Finds stored facts whose embeddings are similar to that of a given statement.
 *
 * The statement is trimmed and embedded with `input.provider`; the search
 * itself is delegated to `searchFactsByEmbedding`, with the topic filter,
 * limit, and minimum score passed through unchanged.
 *
 * @param input - Statement text, embedding provider, and optional filters.
 * @returns Scored facts; an empty array when the trimmed statement is empty.
 * @throws IdentityDBError when the statement embedding's length differs from
 *   `input.provider.dimensions`.
 */
async findSimilarFacts(input: FindSimilarFactsInput): Promise<ScoredFact[]> {
  const candidateText = input.statement.trim();
  if (!candidateText.length) {
    return [];
  }

  const embeddingProvider = input.provider;
  const candidateVector = await embeddingProvider.embed(candidateText);
  this.assertEmbeddingShape(candidateVector, embeddingProvider.dimensions);

  const searchRequest = {
    providerModel: embeddingProvider.model,
    queryEmbedding: candidateVector,
    topicNames: input.topicNames,
    limit: input.limit,
    minimumScore: input.minimumScore,
  };

  return this.searchFactsByEmbedding(searchRequest);
}
|
||||
|
||||
async linkTopics(input: LinkTopicsInput): Promise<void> {
|
||||
@@ -413,6 +521,122 @@ export class IdentityDB {
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * Ranks stored facts by cosine similarity between their stored embedding and
 * a query embedding.
 *
 * Candidate facts are the rows returned by `findFactRowsConnectingTopicIds`
 * when `topicNames` resolves to at least one topic id; otherwise they are all
 * facts that have an embedding row for `providerModel`. In both cases a fact
 * is only scored when it has a stored embedding for `providerModel`, so facts
 * that were never indexed for this model are never returned.
 *
 * Results are ordered by descending score, with ties broken by ascending
 * `created_at`, and truncated to `limit`.
 *
 * @param input.providerModel - Model name used to select stored embeddings.
 * @param input.queryEmbedding - Vector to compare every candidate against.
 * @param input.topicNames - Optional topic filter; resolved via `resolveTopicIds`.
 * @param input.limit - Maximum number of results (defaults to 5).
 * @param input.minimumScore - Lowest score to keep, inclusive (defaults to 0).
 * @returns Hydrated facts, each extended with its `score`; an empty array when
 *   topic resolution yields `null` or nothing qualifies.
 */
private async searchFactsByEmbedding(input: {
  providerModel: string;
  queryEmbedding: number[];
  topicNames?: string[] | undefined;
  limit?: number | undefined;
  minimumScore?: number | undefined;
}): Promise<ScoredFact[]> {
  const topicIds = await this.resolveTopicIds(input.topicNames);
  if (topicIds === null) {
    return [];
  }

  const factRows = topicIds.length > 0
    ? await findFactRowsConnectingTopicIds(this.connection.db, topicIds)
    : await this.connection.db
        .selectFrom('facts')
        .innerJoin('fact_embeddings', 'fact_embeddings.fact_id', 'facts.id')
        .selectAll('facts')
        .where('fact_embeddings.model', '=', input.providerModel)
        .orderBy('facts.created_at', 'asc')
        .execute();

  if (factRows.length === 0) {
    return [];
  }

  // factRows is known to be non-empty here, so the `in` list is never empty.
  const embeddingRows = await this.connection.db
    .selectFrom('fact_embeddings')
    .selectAll()
    .where('model', '=', input.providerModel)
    .where('fact_id', 'in', factRows.map((factRow) => factRow.id))
    .execute();

  const embeddingsByFactId = new Map(
    embeddingRows.map((embeddingRow) => [embeddingRow.fact_id, deserializeEmbedding(embeddingRow.embedding)]),
  );

  const minimumScore = input.minimumScore ?? 0;
  const scoredRows: Array<{ factRow: (typeof factRows)[number]; score: number }> = [];

  for (const factRow of factRows) {
    const storedEmbedding = embeddingsByFactId.get(factRow.id);

    // The topic-filtered candidates are not restricted to embedded facts.
    // Skip facts with no embedding for this model instead of scoring them
    // against an empty vector, matching the inner join used without a filter.
    if (storedEmbedding === undefined) {
      continue;
    }

    const score = cosineSimilarity(input.queryEmbedding, storedEmbedding);
    if (score >= minimumScore) {
      scoredRows.push({ factRow, score });
    }
  }

  // Highest score first; equal scores fall back to oldest fact first.
  scoredRows.sort((left, right) => {
    if (right.score !== left.score) {
      return right.score - left.score;
    }
    return left.factRow.created_at.localeCompare(right.factRow.created_at);
  });

  const topRows = scoredRows.slice(0, input.limit ?? 5);
  if (topRows.length === 0) {
    return [];
  }

  const hydratedFacts = await this.hydrateFacts(topRows.map((entry) => entry.factRow));
  const factsById = new Map(hydratedFacts.map((fact) => [fact.id, fact]));

  // Re-map through topRows so the output keeps the ranking order.
  return topRows.map((entry) => ({
    ...factsById.get(entry.factRow.id)!,
    score: entry.score,
  }));
}
|
||||
|
||||
/**
 * Translates topic names into topic ids.
 *
 * @param topicNames - Names to look up via `getRequiredTopicRow`; may be omitted.
 * @returns An empty array when no names are given, `null` when any lookup
 *   yields a falsy row, otherwise the ids in the same order as the names.
 *   NOTE(review): whether `getRequiredTopicRow` throws or returns a falsy
 *   value for an unknown topic is not visible here — confirm the `null` path
 *   is reachable.
 */
private async resolveTopicIds(topicNames?: string[]): Promise<string[] | null> {
  const requestedNames = topicNames ?? [];
  if (requestedNames.length === 0) {
    return [];
  }

  // Start every lookup before awaiting so they run concurrently.
  const pendingLookups = requestedNames.map((requestedName) => this.getRequiredTopicRow(requestedName));
  const resolvedRows = await Promise.all(pendingLookups);

  const resolvedIds: string[] = [];
  for (const resolvedRow of resolvedRows) {
    if (!resolvedRow) {
      return null;
    }
    resolvedIds.push(resolvedRow.id);
  }

  return resolvedIds;
}
|
||||
|
||||
/**
 * Replaces the stored embedding for a (fact, model) pair.
 *
 * Any existing `fact_embeddings` row for the pair is deleted, then a fresh row
 * is inserted with the serialized embedding, its length, a content hash of the
 * statement, and identical `created_at` / `updated_at` timestamps.
 *
 * @param executor - Database handle or transaction to run both statements on.
 * @param factId - Identifier of the fact the embedding belongs to.
 * @param statement - Fact text; only used to derive `content_hash`.
 * @param embedding - Vector to store.
 * @param model - Name of the model that produced the vector.
 */
private async upsertFactEmbeddingRecord(
  executor: DatabaseExecutor,
  factId: string,
  statement: string,
  embedding: number[],
  model: string,
): Promise<void> {
  const writtenAt = nowIsoString();

  // Clear the previous row for this fact/model before inserting the new one.
  const removeExisting = executor
    .deleteFrom('fact_embeddings')
    .where('fact_id', '=', factId)
    .where('model', '=', model);
  await removeExisting.execute();

  const embeddingRow = {
    fact_id: factId,
    model,
    dimensions: embedding.length,
    embedding: serializeEmbedding(embedding),
    content_hash: createContentHash(statement),
    created_at: writtenAt,
    updated_at: writtenAt,
  };

  await executor.insertInto('fact_embeddings').values(embeddingRow).execute();
}
|
||||
|
||||
/**
 * Verifies that an embedding has the expected number of components.
 *
 * @param embedding - Vector to check.
 * @param expectedDimensions - Required value of `embedding.length`.
 * @throws IdentityDBError when the lengths differ.
 */
private assertEmbeddingShape(embedding: number[], expectedDimensions: number): void {
  const hasExpectedLength = embedding.length === expectedDimensions;
  if (hasExpectedLength) {
    return;
  }

  const message = `Embedding dimension mismatch. Expected ${expectedDimensions}, received ${embedding.length}.`;
  throw new IdentityDBError(message);
}
|
||||
|
||||
private async upsertTopicInExecutor(
|
||||
executor: DatabaseExecutor,
|
||||
input: UpsertTopicInput,
|
||||
|
||||
Reference in New Issue
Block a user