feat: add semantic fact search and embeddings
This commit is contained in:
@@ -2,7 +2,11 @@ import {
|
|||||||
type ConnectedTopic,
|
type ConnectedTopic,
|
||||||
type Fact,
|
type Fact,
|
||||||
type FactTopic,
|
type FactTopic,
|
||||||
|
type FindSimilarFactsInput,
|
||||||
|
type IndexFactEmbeddingsInput,
|
||||||
type ListTopicsOptions,
|
type ListTopicsOptions,
|
||||||
|
type ScoredFact,
|
||||||
|
type SearchFactsInput,
|
||||||
type Topic,
|
type Topic,
|
||||||
type TopicLookupOptions,
|
type TopicLookupOptions,
|
||||||
type TopicWithFacts,
|
type TopicWithFacts,
|
||||||
@@ -19,11 +23,15 @@ import { IdentityDBError } from './errors';
|
|||||||
import { initializeSchema } from './migrations';
|
import { initializeSchema } from './migrations';
|
||||||
import {
|
import {
|
||||||
canonicalizeTopicName,
|
canonicalizeTopicName,
|
||||||
|
cosineSimilarity,
|
||||||
|
createContentHash,
|
||||||
createId,
|
createId,
|
||||||
|
deserializeEmbedding,
|
||||||
mapFactRow,
|
mapFactRow,
|
||||||
mapTopicRow,
|
mapTopicRow,
|
||||||
normalizeTopicName,
|
normalizeTopicName,
|
||||||
nowIsoString,
|
nowIsoString,
|
||||||
|
serializeEmbedding,
|
||||||
serializeMetadata,
|
serializeMetadata,
|
||||||
} from './utils';
|
} from './utils';
|
||||||
import { extractFact } from '../ingestion/extractor';
|
import { extractFact } from '../ingestion/extractor';
|
||||||
@@ -155,7 +163,107 @@ export class IdentityDB {
|
|||||||
factInput.metadata = extracted.metadata;
|
factInput.metadata = extracted.metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.addFact(factInput);
|
if (options.embeddingProvider) {
|
||||||
|
const similarFacts = await this.findSimilarFacts({
|
||||||
|
statement: factInput.statement,
|
||||||
|
provider: options.embeddingProvider,
|
||||||
|
topicNames: factInput.topics.map((topic) => topic.name),
|
||||||
|
limit: 1,
|
||||||
|
minimumScore: options.duplicateThreshold ?? 0.97,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (similarFacts[0]) {
|
||||||
|
return similarFacts[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const fact = await this.addFact(factInput);
|
||||||
|
|
||||||
|
if (options.embeddingProvider) {
|
||||||
|
await this.indexFactEmbedding(fact.id, { provider: options.embeddingProvider });
|
||||||
|
}
|
||||||
|
|
||||||
|
return fact;
|
||||||
|
}
|
||||||
|
|
||||||
|
async indexFactEmbeddings(input: IndexFactEmbeddingsInput): Promise<void> {
|
||||||
|
const factRows = await this.connection.db.selectFrom('facts').selectAll().orderBy('created_at', 'asc').execute();
|
||||||
|
|
||||||
|
if (factRows.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const embeddings = input.provider.embedMany
|
||||||
|
? await input.provider.embedMany(factRows.map((factRow) => factRow.statement))
|
||||||
|
: await Promise.all(factRows.map((factRow) => input.provider.embed(factRow.statement)));
|
||||||
|
|
||||||
|
if (embeddings.length !== factRows.length) {
|
||||||
|
throw new IdentityDBError('Embedding provider returned a mismatched number of embeddings.');
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.connection.db.transaction().execute(async (trx) => {
|
||||||
|
for (let index = 0; index < factRows.length; index += 1) {
|
||||||
|
const factRow = factRows[index]!;
|
||||||
|
const embedding = embeddings[index]!;
|
||||||
|
this.assertEmbeddingShape(embedding, input.provider.dimensions);
|
||||||
|
await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async indexFactEmbedding(factId: string, input: IndexFactEmbeddingsInput): Promise<void> {
|
||||||
|
const factRow = await this.connection.db
|
||||||
|
.selectFrom('facts')
|
||||||
|
.selectAll()
|
||||||
|
.where('id', '=', factId)
|
||||||
|
.executeTakeFirst();
|
||||||
|
|
||||||
|
if (!factRow) {
|
||||||
|
throw new IdentityDBError(`Fact not found: ${factId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const embedding = await input.provider.embed(factRow.statement);
|
||||||
|
this.assertEmbeddingShape(embedding, input.provider.dimensions);
|
||||||
|
|
||||||
|
await this.connection.db.transaction().execute(async (trx) => {
|
||||||
|
await this.upsertFactEmbeddingRecord(trx, factRow.id, factRow.statement, embedding, input.provider.model);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async searchFacts(input: SearchFactsInput): Promise<ScoredFact[]> {
|
||||||
|
const queryText = input.query.trim();
|
||||||
|
if (queryText.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const queryEmbedding = await input.provider.embed(queryText);
|
||||||
|
this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions);
|
||||||
|
|
||||||
|
return this.searchFactsByEmbedding({
|
||||||
|
providerModel: input.provider.model,
|
||||||
|
queryEmbedding,
|
||||||
|
topicNames: input.topicNames,
|
||||||
|
limit: input.limit,
|
||||||
|
minimumScore: input.minimumScore,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async findSimilarFacts(input: FindSimilarFactsInput): Promise<ScoredFact[]> {
|
||||||
|
const statement = input.statement.trim();
|
||||||
|
if (statement.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const queryEmbedding = await input.provider.embed(statement);
|
||||||
|
this.assertEmbeddingShape(queryEmbedding, input.provider.dimensions);
|
||||||
|
|
||||||
|
return this.searchFactsByEmbedding({
|
||||||
|
providerModel: input.provider.model,
|
||||||
|
queryEmbedding,
|
||||||
|
topicNames: input.topicNames,
|
||||||
|
limit: input.limit,
|
||||||
|
minimumScore: input.minimumScore,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
async linkTopics(input: LinkTopicsInput): Promise<void> {
|
async linkTopics(input: LinkTopicsInput): Promise<void> {
|
||||||
@@ -413,6 +521,122 @@ export class IdentityDB {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async searchFactsByEmbedding(input: {
|
||||||
|
providerModel: string;
|
||||||
|
queryEmbedding: number[];
|
||||||
|
topicNames?: string[] | undefined;
|
||||||
|
limit?: number | undefined;
|
||||||
|
minimumScore?: number | undefined;
|
||||||
|
}): Promise<ScoredFact[]> {
|
||||||
|
const topicIds = await this.resolveTopicIds(input.topicNames);
|
||||||
|
if (topicIds === null) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const factRows = topicIds.length > 0
|
||||||
|
? await findFactRowsConnectingTopicIds(this.connection.db, topicIds)
|
||||||
|
: await this.connection.db
|
||||||
|
.selectFrom('facts')
|
||||||
|
.innerJoin('fact_embeddings', 'fact_embeddings.fact_id', 'facts.id')
|
||||||
|
.selectAll('facts')
|
||||||
|
.where('fact_embeddings.model', '=', input.providerModel)
|
||||||
|
.orderBy('facts.created_at', 'asc')
|
||||||
|
.execute();
|
||||||
|
|
||||||
|
if (factRows.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const embeddingRowsQuery = this.connection.db
|
||||||
|
.selectFrom('fact_embeddings')
|
||||||
|
.selectAll()
|
||||||
|
.where('model', '=', input.providerModel);
|
||||||
|
|
||||||
|
const embeddingRows = factRows.length > 0
|
||||||
|
? await embeddingRowsQuery.where('fact_id', 'in', factRows.map((factRow) => factRow.id)).execute()
|
||||||
|
: [];
|
||||||
|
|
||||||
|
const embeddingsByFactId = new Map(
|
||||||
|
embeddingRows.map((embeddingRow) => [embeddingRow.fact_id, deserializeEmbedding(embeddingRow.embedding)]),
|
||||||
|
);
|
||||||
|
|
||||||
|
const scoredRows = factRows
|
||||||
|
.map((factRow) => ({
|
||||||
|
factRow,
|
||||||
|
score: cosineSimilarity(input.queryEmbedding, embeddingsByFactId.get(factRow.id) ?? []),
|
||||||
|
}))
|
||||||
|
.filter((entry) => entry.score >= (input.minimumScore ?? 0))
|
||||||
|
.sort((left, right) => {
|
||||||
|
if (right.score !== left.score) {
|
||||||
|
return right.score - left.score;
|
||||||
|
}
|
||||||
|
return left.factRow.created_at.localeCompare(right.factRow.created_at);
|
||||||
|
})
|
||||||
|
.slice(0, input.limit ?? 5);
|
||||||
|
|
||||||
|
if (scoredRows.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const hydratedFacts = await this.hydrateFacts(scoredRows.map((entry) => entry.factRow));
|
||||||
|
const factsById = new Map(hydratedFacts.map((fact) => [fact.id, fact]));
|
||||||
|
|
||||||
|
return scoredRows.map((entry) => ({
|
||||||
|
...factsById.get(entry.factRow.id)!,
|
||||||
|
score: entry.score,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private async resolveTopicIds(topicNames?: string[]): Promise<string[] | null> {
|
||||||
|
if (!topicNames || topicNames.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const topicRows = await Promise.all(topicNames.map((topicName) => this.getRequiredTopicRow(topicName)));
|
||||||
|
if (topicRows.some((topicRow) => !topicRow)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return topicRows.map((topicRow) => topicRow!.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async upsertFactEmbeddingRecord(
|
||||||
|
executor: DatabaseExecutor,
|
||||||
|
factId: string,
|
||||||
|
statement: string,
|
||||||
|
embedding: number[],
|
||||||
|
model: string,
|
||||||
|
): Promise<void> {
|
||||||
|
const timestamp = nowIsoString();
|
||||||
|
|
||||||
|
await executor
|
||||||
|
.deleteFrom('fact_embeddings')
|
||||||
|
.where('fact_id', '=', factId)
|
||||||
|
.where('model', '=', model)
|
||||||
|
.execute();
|
||||||
|
|
||||||
|
await executor
|
||||||
|
.insertInto('fact_embeddings')
|
||||||
|
.values({
|
||||||
|
fact_id: factId,
|
||||||
|
model,
|
||||||
|
dimensions: embedding.length,
|
||||||
|
embedding: serializeEmbedding(embedding),
|
||||||
|
content_hash: createContentHash(statement),
|
||||||
|
created_at: timestamp,
|
||||||
|
updated_at: timestamp,
|
||||||
|
})
|
||||||
|
.execute();
|
||||||
|
}
|
||||||
|
|
||||||
|
private assertEmbeddingShape(embedding: number[], expectedDimensions: number): void {
|
||||||
|
if (embedding.length !== expectedDimensions) {
|
||||||
|
throw new IdentityDBError(
|
||||||
|
`Embedding dimension mismatch. Expected ${expectedDimensions}, received ${embedding.length}.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async upsertTopicInExecutor(
|
private async upsertTopicInExecutor(
|
||||||
executor: DatabaseExecutor,
|
executor: DatabaseExecutor,
|
||||||
input: UpsertTopicInput,
|
input: UpsertTopicInput,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import type { Kysely } from 'kysely';
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
FACTS_TABLE,
|
FACTS_TABLE,
|
||||||
|
FACT_EMBEDDINGS_TABLE,
|
||||||
FACT_TOPICS_TABLE,
|
FACT_TOPICS_TABLE,
|
||||||
TOPIC_ALIASES_TABLE,
|
TOPIC_ALIASES_TABLE,
|
||||||
TOPIC_RELATIONS_TABLE,
|
TOPIC_RELATIONS_TABLE,
|
||||||
@@ -39,6 +40,21 @@ export async function initializeSchema(
|
|||||||
.addColumn('updated_at', 'text', (column) => column.notNull())
|
.addColumn('updated_at', 'text', (column) => column.notNull())
|
||||||
.execute();
|
.execute();
|
||||||
|
|
||||||
|
await db.schema
|
||||||
|
.createTable(FACT_EMBEDDINGS_TABLE)
|
||||||
|
.ifNotExists()
|
||||||
|
.addColumn('fact_id', 'text', (column) =>
|
||||||
|
column.notNull().references(`${FACTS_TABLE}.id`).onDelete('cascade'),
|
||||||
|
)
|
||||||
|
.addColumn('model', 'text', (column) => column.notNull())
|
||||||
|
.addColumn('dimensions', 'integer', (column) => column.notNull())
|
||||||
|
.addColumn('embedding', 'text', (column) => column.notNull())
|
||||||
|
.addColumn('content_hash', 'text', (column) => column.notNull())
|
||||||
|
.addColumn('created_at', 'text', (column) => column.notNull())
|
||||||
|
.addColumn('updated_at', 'text', (column) => column.notNull())
|
||||||
|
.addPrimaryKeyConstraint('fact_embeddings_pk', ['fact_id', 'model'])
|
||||||
|
.execute();
|
||||||
|
|
||||||
await db.schema
|
await db.schema
|
||||||
.createTable(FACT_TOPICS_TABLE)
|
.createTable(FACT_TOPICS_TABLE)
|
||||||
.ifNotExists()
|
.ifNotExists()
|
||||||
@@ -96,6 +112,13 @@ export async function initializeSchema(
|
|||||||
.column('fact_id')
|
.column('fact_id')
|
||||||
.execute();
|
.execute();
|
||||||
|
|
||||||
|
await db.schema
|
||||||
|
.createIndex('fact_embeddings_model_idx')
|
||||||
|
.ifNotExists()
|
||||||
|
.on(FACT_EMBEDDINGS_TABLE)
|
||||||
|
.column('model')
|
||||||
|
.execute();
|
||||||
|
|
||||||
await db.schema
|
await db.schema
|
||||||
.createIndex('topic_relations_parent_topic_id_idx')
|
.createIndex('topic_relations_parent_topic_id_idx')
|
||||||
.ifNotExists()
|
.ifNotExists()
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ export const FACTS_TABLE = 'facts';
|
|||||||
export const FACT_TOPICS_TABLE = 'fact_topics';
|
export const FACT_TOPICS_TABLE = 'fact_topics';
|
||||||
export const TOPIC_RELATIONS_TABLE = 'topic_relations';
|
export const TOPIC_RELATIONS_TABLE = 'topic_relations';
|
||||||
export const TOPIC_ALIASES_TABLE = 'topic_aliases';
|
export const TOPIC_ALIASES_TABLE = 'topic_aliases';
|
||||||
|
export const FACT_EMBEDDINGS_TABLE = 'fact_embeddings';
|
||||||
|
|
||||||
export const TOPIC_COLUMNS = [
|
export const TOPIC_COLUMNS = [
|
||||||
'id',
|
'id',
|
||||||
@@ -51,3 +52,13 @@ export const TOPIC_ALIAS_COLUMNS = [
|
|||||||
'created_at',
|
'created_at',
|
||||||
'updated_at',
|
'updated_at',
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
|
export const FACT_EMBEDDING_COLUMNS = [
|
||||||
|
'fact_id',
|
||||||
|
'model',
|
||||||
|
'dimensions',
|
||||||
|
'embedding',
|
||||||
|
'content_hash',
|
||||||
|
'created_at',
|
||||||
|
'updated_at',
|
||||||
|
] as const;
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { randomUUID } from 'node:crypto';
|
import { createHash, randomUUID } from 'node:crypto';
|
||||||
|
|
||||||
import type { Fact, FactTopic, Topic } from '../types/api';
|
import type { Fact, FactTopic, Topic } from '../types/api';
|
||||||
import type { FactRecord, TopicRecord } from '../types/domain';
|
import type { FactRecord, TopicRecord } from '../types/domain';
|
||||||
@@ -35,6 +35,42 @@ export function deserializeMetadata(metadata: string | null): unknown | null {
|
|||||||
return JSON.parse(metadata);
|
return JSON.parse(metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function serializeEmbedding(embedding: number[]): string {
|
||||||
|
return JSON.stringify(embedding);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function deserializeEmbedding(embedding: string): number[] {
|
||||||
|
return JSON.parse(embedding) as number[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createContentHash(input: string): string {
|
||||||
|
return createHash('sha256').update(input).digest('hex');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function cosineSimilarity(left: number[], right: number[]): number {
|
||||||
|
if (left.length === 0 || left.length !== right.length) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let dot = 0;
|
||||||
|
let leftMagnitude = 0;
|
||||||
|
let rightMagnitude = 0;
|
||||||
|
|
||||||
|
for (let index = 0; index < left.length; index += 1) {
|
||||||
|
const leftValue = left[index] ?? 0;
|
||||||
|
const rightValue = right[index] ?? 0;
|
||||||
|
dot += leftValue * rightValue;
|
||||||
|
leftMagnitude += leftValue * leftValue;
|
||||||
|
rightMagnitude += rightValue * rightValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (leftMagnitude === 0 || rightMagnitude === 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return dot / (Math.sqrt(leftMagnitude) * Math.sqrt(rightMagnitude));
|
||||||
|
}
|
||||||
|
|
||||||
export function mapTopicRow(record: TopicRecord): Topic {
|
export function mapTopicRow(record: TopicRecord): Topic {
|
||||||
return {
|
return {
|
||||||
id: record.id,
|
id: record.id,
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
import type { AddFactInput, TopicLinkInput } from '../types/api';
|
import type {
|
||||||
|
AddFactInput,
|
||||||
|
EmbeddingProvider,
|
||||||
|
TopicLinkInput,
|
||||||
|
} from '../types/api';
|
||||||
|
|
||||||
export interface ExtractedFact {
|
export interface ExtractedFact {
|
||||||
statement?: string;
|
statement?: string;
|
||||||
@@ -15,4 +19,6 @@ export interface FactExtractor {
|
|||||||
|
|
||||||
export interface IngestStatementOptions {
|
export interface IngestStatementOptions {
|
||||||
extractor: FactExtractor;
|
extractor: FactExtractor;
|
||||||
|
embeddingProvider?: EmbeddingProvider;
|
||||||
|
duplicateThreshold?: number;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,3 +71,34 @@ export interface ListTopicsOptions {
|
|||||||
includeFacts?: boolean;
|
includeFacts?: boolean;
|
||||||
limit?: number;
|
limit?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface EmbeddingProvider {
|
||||||
|
model: string;
|
||||||
|
dimensions: number;
|
||||||
|
embed(input: string): Promise<number[]>;
|
||||||
|
embedMany?(inputs: string[]): Promise<number[][]>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface IndexFactEmbeddingsInput {
|
||||||
|
provider: EmbeddingProvider;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SearchFactsInput {
|
||||||
|
query: string;
|
||||||
|
provider: EmbeddingProvider;
|
||||||
|
topicNames?: string[];
|
||||||
|
limit?: number;
|
||||||
|
minimumScore?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface FindSimilarFactsInput {
|
||||||
|
statement: string;
|
||||||
|
provider: EmbeddingProvider;
|
||||||
|
topicNames?: string[];
|
||||||
|
limit?: number;
|
||||||
|
minimumScore?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ScoredFact extends Fact {
|
||||||
|
score: number;
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import type {
|
import type {
|
||||||
|
FactEmbeddingRecord,
|
||||||
FactRecord,
|
FactRecord,
|
||||||
FactTopicRecord,
|
FactTopicRecord,
|
||||||
TopicAliasRecord,
|
TopicAliasRecord,
|
||||||
@@ -12,4 +13,5 @@ export interface IdentityDatabaseSchema {
|
|||||||
fact_topics: FactTopicRecord;
|
fact_topics: FactTopicRecord;
|
||||||
topic_relations: TopicRelationRecord;
|
topic_relations: TopicRelationRecord;
|
||||||
topic_aliases: TopicAliasRecord;
|
topic_aliases: TopicAliasRecord;
|
||||||
|
fact_embeddings: FactEmbeddingRecord;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,3 +52,13 @@ export interface TopicAliasRecord {
|
|||||||
created_at: string;
|
created_at: string;
|
||||||
updated_at: string;
|
updated_at: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface FactEmbeddingRecord {
|
||||||
|
fact_id: string;
|
||||||
|
model: string;
|
||||||
|
dimensions: number;
|
||||||
|
embedding: string;
|
||||||
|
content_hash: string;
|
||||||
|
created_at: string;
|
||||||
|
updated_at: string;
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ afterEach(async () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('initializeSchema', () => {
|
describe('initializeSchema', () => {
|
||||||
it('creates the topics, facts, fact_topics, topic_relations, and topic_aliases tables', async () => {
|
it('creates the topics, facts, fact_embeddings, fact_topics, topic_relations, and topic_aliases tables', async () => {
|
||||||
const connection = await createDatabase({ client: 'sqlite', filename: ':memory:' });
|
const connection = await createDatabase({ client: 'sqlite', filename: ':memory:' });
|
||||||
openConnections.push(connection.destroy);
|
openConnections.push(connection.destroy);
|
||||||
|
|
||||||
@@ -33,6 +33,7 @@ describe('initializeSchema', () => {
|
|||||||
|
|
||||||
expect(tableNames).toContain('topics');
|
expect(tableNames).toContain('topics');
|
||||||
expect(tableNames).toContain('facts');
|
expect(tableNames).toContain('facts');
|
||||||
|
expect(tableNames).toContain('fact_embeddings');
|
||||||
expect(tableNames).toContain('fact_topics');
|
expect(tableNames).toContain('fact_topics');
|
||||||
expect(tableNames).toContain('topic_relations');
|
expect(tableNames).toContain('topic_relations');
|
||||||
expect(tableNames).toContain('topic_aliases');
|
expect(tableNames).toContain('topic_aliases');
|
||||||
@@ -46,6 +47,7 @@ describe('initializeSchema', () => {
|
|||||||
|
|
||||||
const topicsColumns = await sql<{ name: string }>`PRAGMA table_info(topics)`.execute(connection.db);
|
const topicsColumns = await sql<{ name: string }>`PRAGMA table_info(topics)`.execute(connection.db);
|
||||||
const factsColumns = await sql<{ name: string }>`PRAGMA table_info(facts)`.execute(connection.db);
|
const factsColumns = await sql<{ name: string }>`PRAGMA table_info(facts)`.execute(connection.db);
|
||||||
|
const factEmbeddingsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_embeddings)`.execute(connection.db);
|
||||||
const factTopicsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_topics)`.execute(connection.db);
|
const factTopicsColumns = await sql<{ name: string }>`PRAGMA table_info(fact_topics)`.execute(connection.db);
|
||||||
const topicRelationsColumns = await sql<{ name: string }>`PRAGMA table_info(topic_relations)`.execute(connection.db);
|
const topicRelationsColumns = await sql<{ name: string }>`PRAGMA table_info(topic_relations)`.execute(connection.db);
|
||||||
const topicAliasesColumns = await sql<{ name: string }>`PRAGMA table_info(topic_aliases)`.execute(connection.db);
|
const topicAliasesColumns = await sql<{ name: string }>`PRAGMA table_info(topic_aliases)`.execute(connection.db);
|
||||||
@@ -73,6 +75,16 @@ describe('initializeSchema', () => {
|
|||||||
'updated_at',
|
'updated_at',
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
expect(factEmbeddingsColumns.rows.map((row) => row.name)).toEqual([
|
||||||
|
'fact_id',
|
||||||
|
'model',
|
||||||
|
'dimensions',
|
||||||
|
'embedding',
|
||||||
|
'content_hash',
|
||||||
|
'created_at',
|
||||||
|
'updated_at',
|
||||||
|
]);
|
||||||
|
|
||||||
expect(factTopicsColumns.rows.map((row) => row.name)).toEqual([
|
expect(factTopicsColumns.rows.map((row) => row.name)).toEqual([
|
||||||
'fact_id',
|
'fact_id',
|
||||||
'topic_id',
|
'topic_id',
|
||||||
|
|||||||
170
tests/semantic-search.test.ts
Normal file
170
tests/semantic-search.test.ts
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
|
||||||
|
|
||||||
|
import { IdentityDB } from '../src/core/identity-db';
|
||||||
|
import type { FactExtractor } from '../src/ingestion/types';
|
||||||
|
import type { EmbeddingProvider } from '../src/types/api';
|
||||||
|
|
||||||
|
class FakeEmbeddingProvider implements EmbeddingProvider {
|
||||||
|
model = 'fake-semantic-v1';
|
||||||
|
dimensions = 3;
|
||||||
|
|
||||||
|
async embed(input: string): Promise<number[]> {
|
||||||
|
return embeddingFor(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
async embedMany(inputs: string[]): Promise<number[][]> {
|
||||||
|
return Promise.all(inputs.map((input) => this.embed(input)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function embeddingFor(input: string): number[] {
|
||||||
|
const normalized = input.toLowerCase();
|
||||||
|
|
||||||
|
if (normalized.includes('bun') && normalized.includes('typescript')) {
|
||||||
|
return [1, 0, 0];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalized.includes('tooling') || normalized.includes('runtime')) {
|
||||||
|
return [0.98, 0.02, 0];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalized.includes('typescript')) {
|
||||||
|
return [0.9, 0.1, 0];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalized.includes('python')) {
|
||||||
|
return [0, 1, 0];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalized.includes('database')) {
|
||||||
|
return [0, 0.2, 0.8];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [0.1, 0.1, 0.1];
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('IdentityDB semantic search', () => {
|
||||||
|
let db: IdentityDB;
|
||||||
|
let provider: FakeEmbeddingProvider;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
provider = new FakeEmbeddingProvider();
|
||||||
|
db = await IdentityDB.connect({ client: 'sqlite', filename: ':memory:' });
|
||||||
|
await db.initialize();
|
||||||
|
|
||||||
|
await db.addFact({
|
||||||
|
statement: 'Bun runs TypeScript tooling quickly.',
|
||||||
|
topics: [
|
||||||
|
{ name: 'Bun', category: 'entity', granularity: 'concrete' },
|
||||||
|
{ name: 'TypeScript', category: 'entity', granularity: 'concrete' },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
await db.addFact({
|
||||||
|
statement: 'TypeScript compiles to JavaScript.',
|
||||||
|
topics: [
|
||||||
|
{ name: 'TypeScript', category: 'entity', granularity: 'concrete' },
|
||||||
|
{ name: 'JavaScript', category: 'entity', granularity: 'concrete' },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
await db.addFact({
|
||||||
|
statement: 'Python uses indentation syntax.',
|
||||||
|
topics: [
|
||||||
|
{ name: 'Python', category: 'entity', granularity: 'concrete' },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(async () => {
|
||||||
|
await db.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('indexes facts and returns semantic search matches ordered by score', async () => {
|
||||||
|
await db.indexFactEmbeddings({ provider });
|
||||||
|
|
||||||
|
const matches = await db.searchFacts({
|
||||||
|
query: 'TypeScript runtime tooling',
|
||||||
|
provider,
|
||||||
|
limit: 2,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(matches).toHaveLength(2);
|
||||||
|
expect(matches[0]?.statement).toBe('Bun runs TypeScript tooling quickly.');
|
||||||
|
expect(matches[1]?.statement).toBe('TypeScript compiles to JavaScript.');
|
||||||
|
expect(matches[0]!.score).toBeGreaterThan(matches[1]!.score);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('filters semantic search candidates by topic names', async () => {
|
||||||
|
await db.indexFactEmbeddings({ provider });
|
||||||
|
|
||||||
|
const matches = await db.searchFacts({
|
||||||
|
query: 'TypeScript runtime tooling',
|
||||||
|
provider,
|
||||||
|
topicNames: ['Python'],
|
||||||
|
limit: 5,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(matches.map((match) => match.statement)).toEqual(['Python uses indentation syntax.']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('finds similar facts from an input statement', async () => {
|
||||||
|
await db.indexFactEmbeddings({ provider });
|
||||||
|
|
||||||
|
const matches = await db.findSimilarFacts({
|
||||||
|
statement: 'Bun makes TypeScript tooling fast.',
|
||||||
|
provider,
|
||||||
|
limit: 2,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(matches[0]?.statement).toBe('Bun runs TypeScript tooling quickly.');
|
||||||
|
expect(matches[0]!.score).toBeGreaterThan(matches[1]!.score);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('IdentityDB dedup-aware ingestion', () => {
|
||||||
|
let db: IdentityDB;
|
||||||
|
let provider: FakeEmbeddingProvider;
|
||||||
|
let extractor: FactExtractor;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
provider = new FakeEmbeddingProvider();
|
||||||
|
extractor = {
|
||||||
|
async extract(input) {
|
||||||
|
return {
|
||||||
|
statement: input,
|
||||||
|
topics: [
|
||||||
|
{ name: 'Bun', category: 'entity', granularity: 'concrete' },
|
||||||
|
{ name: 'TypeScript', category: 'entity', granularity: 'concrete' },
|
||||||
|
],
|
||||||
|
};
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
db = await IdentityDB.connect({ client: 'sqlite', filename: ':memory:' });
|
||||||
|
await db.initialize();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(async () => {
|
||||||
|
await db.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns the existing fact when ingestion detects a semantic duplicate', async () => {
|
||||||
|
const first = await db.ingestStatement('Bun runs TypeScript tooling quickly.', {
|
||||||
|
extractor,
|
||||||
|
embeddingProvider: provider,
|
||||||
|
});
|
||||||
|
|
||||||
|
const second = await db.ingestStatement('Bun makes TypeScript tooling fast.', {
|
||||||
|
extractor,
|
||||||
|
embeddingProvider: provider,
|
||||||
|
duplicateThreshold: 0.95,
|
||||||
|
});
|
||||||
|
|
||||||
|
const facts = await db.getTopicFacts('TypeScript');
|
||||||
|
|
||||||
|
expect(second.id).toBe(first.id);
|
||||||
|
expect(facts).toHaveLength(1);
|
||||||
|
expect(facts[0]?.statement).toBe('Bun runs TypeScript tooling quickly.');
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user