Files
pgvecterAPI/main.go
2026-02-09 18:31:30 +09:00

2242 lines
67 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"math"
"net"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// ctxKey is an unexported type for request-context keys, so values stored by
// this package cannot collide with context keys defined in other packages.
type ctxKey string
// Context keys populated by withAPIKeyAuth for authenticated requests.
const (
// ctxPermissionsUsers holds the authenticated key's PermissionsUsers ([]string).
ctxPermissionsUsers ctxKey = "permissions_users"
// ctxAPIKeyID holds the authenticated key's ID (string).
ctxAPIKeyID ctxKey = "api_key_id"
)
// ErrNotFound is returned by the APIKeyStore implementations when no key has
// the requested ID.
var ErrNotFound = errors.New("not found")
// ErrorResponse is the JSON error body. NOTE(review): presumably written by
// writeError (defined elsewhere in this file) — confirm the field mapping.
type ErrorResponse struct {
// Error is the human-readable message.
Error string `json:"error"`
// Code is a short machine-readable code such as "invalid_request".
Code string `json:"code"`
}
// HealthResponse is the JSON body returned by handleHealth.
type HealthResponse struct {
// Status is always "ok".
Status string `json:"status"`
// Timestamp is the server time (UTC) when the response was built.
Timestamp time.Time `json:"timestamp"`
}
// KbSearchRequest is the JSON body accepted by handleKbSearch.
type KbSearchRequest struct {
// Query is the search text (required). It is embedded unless QueryEmbedding is set.
Query string `json:"query"`
// QueryEmbedding optionally supplies the query vector; its length must equal embeddingDim().
QueryEmbedding []float64 `json:"query_embedding"`
// TopK is the maximum number of results: 0 selects 5, otherwise it must be 1-50.
TopK int `json:"top_k"`
// Filter is an optional metadata filter in the format buildFilterSQL accepts.
// Filters that reference metadata.permissions are audited and rejected.
Filter map[string]interface{} `json:"filter"`
// Collection and Collections are mutually exclusive; exactly one must be set.
Collection string `json:"collection"`
Collections []string `json:"collections"`
}
// KbSearchResponse is the JSON body returned by handleKbSearch.
type KbSearchResponse struct {
Results []KbSearchResult `json:"results"`
}
// KbSearchResult is one matching chunk.
type KbSearchResult struct {
ID string `json:"id"`
Collection string `json:"collection"`
Content string `json:"content"`
// Metadata is the chunk's stored JSONB metadata (for rows written by
// handleKbUpsert this includes "collection" and "permissions").
Metadata map[string]interface{} `json:"metadata"`
// Score is the vector distance converted by distanceToSimilarity;
// a NaN or infinite distance is reported as 0.
Score float64 `json:"score"`
}
// CollectionItem is a collection name with the number of chunks it contains.
type CollectionItem struct {
Name string `json:"name"`
Count int `json:"count"`
}
// CollectionsResponse is the JSON body returned by handleCollections and
// handleAdminCollections.
type CollectionsResponse struct {
Items []CollectionItem `json:"items"`
}
// PermissionsAuditItem is one kb_permissions_audit entry as written by
// recordPermissionsAudit: a request that tried to supply metadata.permissions
// (upsert) or filter on metadata.permissions (search/delete) and was rejected.
type PermissionsAuditItem struct {
ID string `json:"id"`
// APIKeyID identifies the calling key.
APIKeyID string `json:"api_key_id"`
// PermissionsUsers are the calling key's permissions at request time.
PermissionsUsers []string `json:"permissions_users"`
RequestedAt time.Time `json:"requested_at"`
// Endpoint is "<METHOD> <path>" of the offending request.
Endpoint string `json:"endpoint"`
// RemoteAddr is http.Request.RemoteAddr verbatim (typically "ip:port").
RemoteAddr string `json:"remote_addr,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
// ProvidedPermissions is what the client sent: the metadata.permissions
// value for upserts, or {"filter": ...} for search/delete.
ProvidedPermissions map[string]interface{} `json:"provided_permissions"`
// ProvidedMetadata is the full metadata (upsert) or {"filter": ...} (search/delete).
ProvidedMetadata map[string]interface{} `json:"provided_metadata"`
}
// PermissionsAuditResponse is the JSON list of audit entries for the admin API.
type PermissionsAuditResponse struct {
Items []PermissionsAuditItem `json:"items"`
}
// KbUpsertRequest is the JSON body accepted by handleKbUpsert.
type KbUpsertRequest struct {
// ID is the chunk UUID; required (and must be a UUID) unless AutoChunk is set.
ID string `json:"id"`
// DocID groups the chunks of one document; must be a UUID, generated when empty.
DocID string `json:"doc_id"`
// ChunkIndex is stored for single-row upserts; must be >= 0.
ChunkIndex int `json:"chunk_index"`
// Collection defaults to "default" when empty.
Collection string `json:"collection"`
// Content is the text to store (required).
Content string `json:"content"`
// Metadata is stored as JSONB. Supplying metadata.permissions is audited and
// rejected; the "collection" and "permissions" keys are always set by the server.
Metadata map[string]interface{} `json:"metadata"`
// Embedding optionally supplies the vector (length embeddingDim()). It is
// rejected together with AutoChunk.
Embedding []float64 `json:"embedding"`
// AutoChunk splits Content into separately embedded chunks.
AutoChunk bool `json:"auto_chunk"`
// ChunkSize is the chunking threshold/size: 0 selects 800, minimum 200.
// Content longer than ChunkSize runes is chunked even without AutoChunk.
ChunkSize int `json:"chunk_size"`
// Overlap between consecutive chunks: 0 selects 100; must be < ChunkSize.
Overlap int `json:"chunk_overlap"`
}
// KbUpsertResponse is returned by handleKbUpsert. For a chunked upsert, ID is
// the first chunk's ID and Chunks lists every stored chunk.
type KbUpsertResponse struct {
ID string `json:"id"`
DocID string `json:"doc_id"`
// Updated reports whether an existing row was replaced (any chunk, when chunked).
Updated bool `json:"updated"`
Chunks []struct {
ID string `json:"id"`
ChunkIndex int `json:"chunk_index"`
} `json:"chunks,omitempty"`
}
// KbDeleteRequest is the JSON body accepted by handleKbDelete. Every criterion
// that is set is ANDed with the caller's permission filter.
type KbDeleteRequest struct {
ID string `json:"id"`
DocID string `json:"doc_id"`
Collection string `json:"collection"`
// Filter uses the buildFilterSQL format and may not reference metadata.permissions.
Filter map[string]interface{} `json:"filter"`
// DryRun counts the matching rows instead of deleting them.
DryRun bool `json:"dry_run"`
}
// KbDeleteResponse reports how many rows were deleted (or matched, for a dry run).
type KbDeleteResponse struct {
Deleted int64 `json:"deleted"`
}
// AdminApiKey is a stored client API key record as returned by APIKeyStore.
// It never contains the raw key; only its hash is stored.
type AdminApiKey struct {
ID string `json:"id"`
Label string `json:"label"`
// PermissionsUsers is written into metadata.permissions.users of chunks the
// key upserts and is used as the permission filter for its reads and deletes.
PermissionsUsers []string `json:"permissions_users"`
// Status is "active" or "revoked"; only active keys authenticate.
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
// LastUsedAt is stamped by Authenticate; nil if the key was never used.
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
}
// AdminApiKeyListResponse is the JSON list of keys for the admin API.
type AdminApiKeyListResponse struct {
Items []AdminApiKey `json:"items"`
}
// AdminApiKeyCreateRequest is the admin request to issue a new key; both
// fields are required by APIKeyStore.Create.
type AdminApiKeyCreateRequest struct {
Label string `json:"label"`
PermissionsUsers []string `json:"permissions_users"`
}
// AdminApiKeyCreateResponse returns the newly issued key. APIKey is the raw
// key; since only its hash is stored, it cannot be retrieved again later.
type AdminApiKeyCreateResponse struct {
APIKey string `json:"api_key"`
Key AdminApiKey `json:"key"`
}
// AdminApiKeyUpdateRequest changes a key's label and/or permissions.
// NOTE(review): presumably passed straight to APIKeyStore.Update, where a nil
// Label / nil PermissionsUsers means "leave unchanged" — confirm in the handler.
type AdminApiKeyUpdateRequest struct {
Label *string `json:"label"`
PermissionsUsers []string `json:"permissions_users"`
}
// APIKeyStore persists client API keys. MemoryKeyStore and PostgresKeyStore
// implement it; both keep only hashKey(raw key), never the raw key itself.
type APIKeyStore interface {
// Create validates label and permissions, issues a new random key, and
// returns the raw key together with the stored record.
Create(ctx context.Context, label string, permissions []string) (string, AdminApiKey, error)
// List returns all keys (active and revoked), oldest first.
List(ctx context.Context) ([]AdminApiKey, error)
// Update changes the label and/or permissions (nil leaves a field unchanged).
// It returns ErrNotFound for an unknown id.
Update(ctx context.Context, id string, label *string, permissions []string) (AdminApiKey, error)
// Revoke sets the key's status to "revoked"; ErrNotFound for an unknown id.
Revoke(ctx context.Context, id string) (AdminApiKey, error)
// Authenticate resolves a raw key. It returns (record, true, nil) for an
// active key, (nil, false, nil) for an unknown or revoked key, and a non-nil
// error only when the lookup itself fails.
Authenticate(ctx context.Context, apiKey string) (*AdminApiKey, bool, error)
}
// MemoryKeyStore is an in-process APIKeyStore backed by maps; its contents do
// not survive a restart. Every field access is guarded by mu.
type MemoryKeyStore struct {
mu sync.Mutex
// keys maps key ID to its record.
keys map[string]*AdminApiKey
// hashToID maps hashKey(raw key) to key ID; raw keys are never stored.
hashToID map[string]string
}
// NewMemoryKeyStore returns an empty in-memory key store with its lookup
// maps initialised and ready for use.
func NewMemoryKeyStore() *MemoryKeyStore {
	store := new(MemoryKeyStore)
	store.keys = map[string]*AdminApiKey{}
	store.hashToID = map[string]string{}
	return store
}
// Create issues a new active API key with the given label and permissions.
// It returns the raw key and a copy of the stored record. The store itself
// keeps only the key's hash, so the raw key cannot be recovered later.
// Errors: a blank label, empty permissions, or a random-source failure.
func (s *MemoryKeyStore) Create(_ context.Context, label string, permissions []string) (string, AdminApiKey, error) {
	switch {
	case strings.TrimSpace(label) == "":
		return "", AdminApiKey{}, errors.New("label is required")
	case len(permissions) == 0:
		return "", AdminApiKey{}, errors.New("permissions_users is required")
	}

	raw, err := generateAPIKey()
	if err != nil {
		return "", AdminApiKey{}, err
	}
	digest := hashKey(raw)
	createdAt := time.Now().UTC()

	s.mu.Lock()
	defer s.mu.Unlock()

	record := &AdminApiKey{
		ID:    generateID(),
		Label: label,
		// Copy so the caller's slice cannot alias stored state.
		PermissionsUsers: append([]string(nil), permissions...),
		Status:           "active",
		CreatedAt:        createdAt,
	}
	s.keys[record.ID] = record
	s.hashToID[digest] = record.ID
	return raw, *record, nil
}
// List returns a snapshot (copies) of every stored key, active and revoked,
// ordered by creation time with the oldest first. The slice is never nil.
func (s *MemoryKeyStore) List(_ context.Context) ([]AdminApiKey, error) {
	s.mu.Lock()
	snapshot := make([]AdminApiKey, 0, len(s.keys))
	for _, rec := range s.keys {
		snapshot = append(snapshot, *rec)
	}
	s.mu.Unlock()

	// snapshot is local, so it can be sorted outside the lock.
	sort.Slice(snapshot, func(a, b int) bool {
		return snapshot[a].CreatedAt.Before(snapshot[b].CreatedAt)
	})
	return snapshot, nil
}
// Update changes the label and/or permissions of the key with the given ID.
// A nil label or nil permissions slice leaves that field unchanged. Validation
// errors occur when both are nil, when label is blank, or when permissions is
// non-nil but empty. ErrNotFound is returned when no key has that ID.
//
// All inputs are validated before the record is touched, so a rejected update
// never leaves the key partially modified.
func (s *MemoryKeyStore) Update(_ context.Context, id string, label *string, permissions []string) (AdminApiKey, error) {
	// Same checks, order and messages as PostgresKeyStore.Update. Applying the
	// label first and only then rejecting an empty permissions slice would
	// return an error while leaving the label changed.
	if label == nil && permissions == nil {
		return AdminApiKey{}, errors.New("no fields to update")
	}
	if label != nil && strings.TrimSpace(*label) == "" {
		return AdminApiKey{}, errors.New("label is required")
	}
	if permissions != nil && len(permissions) == 0 {
		return AdminApiKey{}, errors.New("permissions_users is required")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	k, ok := s.keys[id]
	if !ok {
		return AdminApiKey{}, ErrNotFound
	}
	if label != nil {
		k.Label = *label
	}
	if permissions != nil {
		// Copy so later changes to the caller's slice don't leak into the store.
		k.PermissionsUsers = append([]string(nil), permissions...)
	}
	return *k, nil
}
// Revoke sets the key's status to "revoked" so that Authenticate rejects it,
// and returns the updated record. Revoking an already-revoked key succeeds.
// ErrNotFound is returned for an unknown ID.
func (s *MemoryKeyStore) Revoke(_ context.Context, id string) (AdminApiKey, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if rec, found := s.keys[id]; found {
		rec.Status = "revoked"
		return *rec, nil
	}
	return AdminApiKey{}, ErrNotFound
}
// Authenticate looks up a raw API key by its hash. For an active key it stamps
// LastUsedAt with the current UTC time and returns (record, true, nil).
// Unknown or revoked keys yield (nil, false, nil); it never returns an error.
//
// The returned record is a private copy. Returning the map's own pointer
// would let callers read fields such as PermissionsUsers without holding
// s.mu, while Update/Revoke write those fields under the lock: a data race.
func (s *MemoryKeyStore) Authenticate(_ context.Context, apiKey string) (*AdminApiKey, bool, error) {
	keyHash := hashKey(apiKey)
	s.mu.Lock()
	defer s.mu.Unlock()
	id, ok := s.hashToID[keyHash]
	if !ok {
		return nil, false, nil
	}
	k, ok := s.keys[id]
	if !ok || k.Status != "active" {
		return nil, false, nil
	}
	now := time.Now().UTC()
	k.LastUsedAt = &now
	snapshot := *k
	// Detach the slice too, so callers cannot mutate stored permissions.
	snapshot.PermissionsUsers = append([]string(nil), k.PermissionsUsers...)
	return &snapshot, true, nil
}
// PostgresKeyStore is an APIKeyStore backed by the api_keys table
// (see EnsureSchema). Raw keys are never stored, only hashKey(raw).
type PostgresKeyStore struct {
pool *pgxpool.Pool
}
// NewPostgresKeyStore returns a key store that issues its queries through
// pool. It does not create the table; EnsureSchema can do that.
func NewPostgresKeyStore(pool *pgxpool.Pool) *PostgresKeyStore {
	store := PostgresKeyStore{pool: pool}
	return &store
}
func (s *PostgresKeyStore) EnsureSchema(ctx context.Context) error {
schema := `
CREATE TABLE IF NOT EXISTS api_keys (
id text PRIMARY KEY,
label text NOT NULL,
api_key_hash text NOT NULL UNIQUE,
permissions_users text[] NOT NULL,
status text NOT NULL,
created_at timestamptz NOT NULL,
last_used_at timestamptz
);
`
_, err := s.pool.Exec(ctx, schema)
return err
}
// Create issues a new active API key. It inserts the row (storing only
// hashKey of the raw key) and returns the raw key plus the inserted record as
// read back via RETURNING. Errors: blank label, empty permissions, a
// random-source failure, or a database error.
func (s *PostgresKeyStore) Create(ctx context.Context, label string, permissions []string) (string, AdminApiKey, error) {
	switch {
	case strings.TrimSpace(label) == "":
		return "", AdminApiKey{}, errors.New("label is required")
	case len(permissions) == 0:
		return "", AdminApiKey{}, errors.New("permissions_users is required")
	}

	raw, err := generateAPIKey()
	if err != nil {
		return "", AdminApiKey{}, err
	}

	const insertSQL = `
INSERT INTO api_keys (id, label, api_key_hash, permissions_users, status, created_at)
VALUES ($1, $2, $3, $4, 'active', $5)
RETURNING id, label, permissions_users, status, created_at, last_used_at
`
	var created AdminApiKey
	scanErr := s.pool.QueryRow(ctx, insertSQL, generateID(), label, hashKey(raw), permissions, time.Now().UTC()).
		Scan(&created.ID, &created.Label, &created.PermissionsUsers, &created.Status, &created.CreatedAt, &created.LastUsedAt)
	if scanErr != nil {
		return "", AdminApiKey{}, scanErr
	}
	return raw, created, nil
}
// List returns every API key (active and revoked), ordered by created_at
// with the oldest first.
//
// The result is never nil, so an empty store serialises as [] rather than null
// (for example inside AdminApiKeyListResponse), matching MemoryKeyStore.List.
// On any error the partially read rows are discarded and nil is returned.
func (s *PostgresKeyStore) List(ctx context.Context) ([]AdminApiKey, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id, label, permissions_users, status, created_at, last_used_at
FROM api_keys
ORDER BY created_at ASC
`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := make([]AdminApiKey, 0)
	for rows.Next() {
		var k AdminApiKey
		if err := rows.Scan(&k.ID, &k.Label, &k.PermissionsUsers, &k.Status, &k.CreatedAt, &k.LastUsedAt); err != nil {
			return nil, err
		}
		items = append(items, k)
	}
	// Iteration can stop early on a network or decode error; don't return a
	// truncated list together with that error.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// Update changes a key's label and/or permissions and returns the updated
// record. A nil label or nil permissions leaves that column unchanged: NULL
// parameters fall through COALESCE to the current value.
// Errors occur when both are nil, when label is blank, or when permissions is
// non-nil but empty. ErrNotFound is returned when no row has the given id.
func (s *PostgresKeyStore) Update(ctx context.Context, id string, label *string, permissions []string) (AdminApiKey, error) {
	switch {
	case label == nil && permissions == nil:
		return AdminApiKey{}, errors.New("no fields to update")
	case label != nil && strings.TrimSpace(*label) == "":
		return AdminApiKey{}, errors.New("label is required")
	case permissions != nil && len(permissions) == 0:
		return AdminApiKey{}, errors.New("permissions_users is required")
	}

	const updateSQL = `
UPDATE api_keys
SET label = COALESCE($2, label),
permissions_users = COALESCE($3, permissions_users)
WHERE id = $1
RETURNING id, label, permissions_users, status, created_at, last_used_at
`
	var updated AdminApiKey
	err := s.pool.QueryRow(ctx, updateSQL, id, label, permissions).
		Scan(&updated.ID, &updated.Label, &updated.PermissionsUsers, &updated.Status, &updated.CreatedAt, &updated.LastUsedAt)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return AdminApiKey{}, ErrNotFound
	case err != nil:
		return AdminApiKey{}, err
	}
	return updated, nil
}
// Revoke sets the key's status to 'revoked' and returns the updated record.
// Revoking an already-revoked key succeeds. ErrNotFound is returned when no
// row has the given id.
func (s *PostgresKeyStore) Revoke(ctx context.Context, id string) (AdminApiKey, error) {
	const revokeSQL = `
UPDATE api_keys
SET status = 'revoked'
WHERE id = $1
RETURNING id, label, permissions_users, status, created_at, last_used_at
`
	var revoked AdminApiKey
	err := s.pool.QueryRow(ctx, revokeSQL, id).
		Scan(&revoked.ID, &revoked.Label, &revoked.PermissionsUsers, &revoked.Status, &revoked.CreatedAt, &revoked.LastUsedAt)
	if errors.Is(err, pgx.ErrNoRows) {
		return AdminApiKey{}, ErrNotFound
	}
	if err != nil {
		return AdminApiKey{}, err
	}
	return revoked, nil
}
// Authenticate resolves a raw API key. The key is hashed with hashKey and
// matched against api_key_hash, and only rows with status 'active' qualify.
// On success last_used_at is set to the current UTC time, and the updated
// record is returned with true. Unknown or revoked keys yield (nil, false, nil);
// database failures are returned as errors.
//
// Lookup, status check and last_used_at stamp are a single UPDATE ... RETURNING.
// With a separate SELECT followed by an UPDATE, a key revoked between the two
// statements would still authenticate this request. The single statement
// re-checks status = 'active' on the row it updates, and costs one round trip
// per request instead of two.
func (s *PostgresKeyStore) Authenticate(ctx context.Context, apiKey string) (*AdminApiKey, bool, error) {
	var key AdminApiKey
	err := s.pool.QueryRow(ctx, `
UPDATE api_keys
SET last_used_at = $2
WHERE api_key_hash = $1 AND status = 'active'
RETURNING id, label, permissions_users, status, created_at, last_used_at
`, hashKey(apiKey), time.Now().UTC()).
		Scan(&key.ID, &key.Label, &key.PermissionsUsers, &key.Status, &key.CreatedAt, &key.LastUsedAt)
	if errors.Is(err, pgx.ErrNoRows) {
		// No row matched: the key is unknown or not active.
		return nil, false, nil
	}
	if err != nil {
		return nil, false, err
	}
	return &key, true, nil
}
func generateAPIKey() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return "pgv_" + base64.RawURLEncoding.EncodeToString(buf), nil
}
// generateID returns a random RFC 4122 version-4 UUID in canonical lowercase
// form (xxxxxxxx-xxxx-4xxx-[89ab]xxx-xxxxxxxxxxxx), drawn from crypto/rand.
//
// It panics if the system random source fails. Ignoring that error would
// yield predictable or repeated IDs. These IDs are primary keys, and chunk IDs
// go through ON CONFLICT (id) DO UPDATE, so a repeated ID would silently
// overwrite an unrelated row. Failing loudly is the safe choice; since Go 1.24
// crypto/rand.Read itself crashes the program instead of returning an error.
func generateID() string {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		panic("generateID: reading crypto/rand: " + err.Error())
	}
	buf[6] = (buf[6] & 0x0f) | 0x40 // version 4
	buf[8] = (buf[8] & 0x3f) | 0x80 // RFC 4122 variant (10xx)
	return hex.EncodeToString(buf[0:4]) + "-" +
		hex.EncodeToString(buf[4:6]) + "-" +
		hex.EncodeToString(buf[6:8]) + "-" +
		hex.EncodeToString(buf[8:10]) + "-" +
		hex.EncodeToString(buf[10:16])
}
// parseUUID checks that value has the canonical 8-4-4-4-12 hexadecimal UUID
// layout. Either letter case is accepted, and version/variant bits are not
// checked. It returns value unchanged on success, or "" and an
// "invalid uuid format" error otherwise.
func parseUUID(value string) (string, error) {
	invalid := func() (string, error) { return "", errors.New("invalid uuid format") }
	if len(value) != 36 {
		return invalid()
	}
	for _, pos := range [...]int{8, 13, 18, 23} {
		if value[pos] != '-' {
			return invalid()
		}
	}
	// With the four separators in place, exactly 32 hex digits must remain
	// (any extra '-' elsewhere shortens this).
	digits := strings.ReplaceAll(value, "-", "")
	if len(digits) != 32 {
		return invalid()
	}
	if _, err := hex.DecodeString(digits); err != nil {
		return invalid()
	}
	return value, nil
}
// hashKey returns the lowercase hex SHA-256 digest of raw. Both key stores
// persist and look up API keys only by this digest.
func hashKey(raw string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error.
	h.Write([]byte(raw))
	return hex.EncodeToString(h.Sum(nil))
}
// withAdminAuth protects next with the static admin key. Responses:
//   - 503 "admin_key_not_configured" when adminKey is empty (admin API disabled);
//   - 401 "unauthorized" unless the X-ADMIN-API-KEY header matches adminKey.
//
// The header is compared through fixed-length SHA-256 digests, not with ==
// on the raw strings. String == stops at the first differing byte, so response
// timing would reveal how much of a guessed key is correct, and its length.
// Digests have a fixed length. A prefix match on a digest gives an attacker
// nothing usable for choosing the next guess.
func withAdminAuth(adminKey string, next http.Handler) http.Handler {
	wantDigest := sha256.Sum256([]byte(adminKey))
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if adminKey == "" {
			writeError(w, http.StatusServiceUnavailable, "admin_key_not_configured", "ADMIN_API_KEY is not set")
			return
		}
		key := r.Header.Get("X-ADMIN-API-KEY")
		if key == "" || sha256.Sum256([]byte(key)) != wantDigest {
			writeError(w, http.StatusUnauthorized, "unauthorized", "invalid admin api key")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// withAPIKeyAuth authenticates each request by its X-API-KEY header against
// store. A missing key or an unknown/revoked key gets 401; a store failure gets
// 500. On success the key's PermissionsUsers and ID are placed in the request
// context under ctxPermissionsUsers and ctxAPIKeyID before calling next.
func withAPIKeyAuth(store APIKeyStore, next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		presented := r.Header.Get("X-API-KEY")
		if presented == "" {
			writeError(w, http.StatusUnauthorized, "unauthorized", "missing api key")
			return
		}
		record, valid, err := store.Authenticate(r.Context(), presented)
		switch {
		case err != nil:
			writeError(w, http.StatusInternalServerError, "auth_failed", "failed to authenticate api key")
			return
		case !valid || record == nil:
			writeError(w, http.StatusUnauthorized, "unauthorized", "invalid api key")
			return
		}
		ctx := r.Context()
		ctx = context.WithValue(ctx, ctxPermissionsUsers, record.PermissionsUsers)
		ctx = context.WithValue(ctx, ctxAPIKeyID, record.ID)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(serve)
}
// handleHealth is a liveness probe. It always answers 200 with status "ok"
// and the current UTC time, and does not touch the database.
func handleHealth(w http.ResponseWriter, r *http.Request) {
	resp := HealthResponse{Status: "ok"}
	resp.Timestamp = time.Now().UTC()
	writeJSON(w, http.StatusOK, resp)
}
func handleNotImplemented(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotImplemented, "not_implemented", "endpoint not implemented yet")
}
// handleKbUpsert returns the handler that inserts or replaces knowledge-base
// chunks in kb_doc_chunks.
//
// metadata.collection is overwritten with the effective collection ("default"
// when omitted). metadata.permissions.users is always set to the caller's
// permissions_users. A client-supplied metadata.permissions is recorded via
// recordPermissionsAudit and rejected.
//
// Two modes:
//   - Single row (auto_chunk=false and content no longer than chunk_size runes):
//     the row is upserted under the caller's UUID id, using embedding when given
//     or generating one otherwise.
//   - Chunked (auto_chunk=true, or content longer than chunk_size runes): content
//     is split with splitIntoChunks and every chunk is embedded. All chunk rows
//     are then written in one transaction, each under a freshly generated id.
//     req.id, chunk_index and embedding are not used in this mode.
//
// chunk_size: 0 selects 800, minimum 200. chunk_overlap: 0 selects 100, must
// be < chunk_size.
//
// NOTE(review): chunked mode generates new random ids on every call. Upserting
// the same doc_id again therefore adds rows instead of replacing the old
// chunks, and "updated" is effectively always false there. Also,
// ON CONFLICT (id) DO UPDATE replaces an existing row regardless of that row's
// current metadata.permissions. Confirm both are intended.
func handleKbUpsert(db *pgxpool.Pool) http.HandlerFunc {
	// upsertSQL writes one chunk row, replacing it when the id already exists.
	// The RETURNING expression is scanned into the "updated" flag (xmax is
	// non-zero when the ON CONFLICT update path was taken).
	const upsertSQL = `
INSERT INTO kb_doc_chunks (id, doc_id, chunk_index, collection, content, metadata, embedding, updated_at)
VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::vector, now())
ON CONFLICT (id) DO UPDATE
SET doc_id = EXCLUDED.doc_id,
chunk_index = EXCLUDED.chunk_index,
collection = EXCLUDED.collection,
content = EXCLUDED.content,
metadata = EXCLUDED.metadata,
embedding = EXCLUDED.embedding,
updated_at = now()
RETURNING (xmax <> 0) AS updated
`
	return func(w http.ResponseWriter, r *http.Request) {
		if db == nil {
			writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
			return
		}
		ctx := r.Context()
		permissions, ok := ctx.Value(ctxPermissionsUsers).([]string)
		if !ok || len(permissions) == 0 {
			writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
			return
		}
		var req KbUpsertRequest
		if err := decodeJSON(r, &req); err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}
		if hasMetadataPermissions(req.Metadata) {
			// Clients may not choose their own ACL: record the attempt, then reject it.
			apiKeyID, _ := ctx.Value(ctxAPIKeyID).(string)
			if err := recordPermissionsAudit(ctx, db, apiKeyID, permissions, r, req.Metadata["permissions"], req.Metadata); err != nil {
				writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
				return
			}
			writeError(w, http.StatusBadRequest, "invalid_request", "metadata.permissions is not allowed")
			return
		}
		if strings.TrimSpace(req.Content) == "" {
			writeError(w, http.StatusBadRequest, "invalid_request", "content is required")
			return
		}
		if strings.TrimSpace(req.Collection) == "" {
			req.Collection = "default"
		}
		// 0 means "use the default" for both chunking settings.
		if req.ChunkSize == 0 {
			req.ChunkSize = 800
		}
		if req.Overlap == 0 {
			req.Overlap = 100
		}
		if req.ChunkSize < 200 {
			writeError(w, http.StatusBadRequest, "invalid_request", "chunk_size is too small")
			return
		}
		if req.Overlap < 0 || req.Overlap >= req.ChunkSize {
			writeError(w, http.StatusBadRequest, "invalid_request", "chunk_overlap must be between 0 and chunk_size-1")
			return
		}
		dim := embeddingDim()
		if len(req.Embedding) > 0 && req.AutoChunk {
			writeError(w, http.StatusBadRequest, "invalid_request", "embedding cannot be provided when auto_chunk is true")
			return
		}
		if strings.TrimSpace(req.DocID) == "" {
			req.DocID = generateID()
		} else if _, err := parseUUID(req.DocID); err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "doc_id must be uuid")
			return
		}
		if req.ChunkIndex < 0 {
			writeError(w, http.StatusBadRequest, "invalid_request", "chunk_index must be >= 0")
			return
		}
		if !req.AutoChunk {
			if strings.TrimSpace(req.ID) == "" {
				writeError(w, http.StatusBadRequest, "invalid_request", "id is required")
				return
			}
			if _, err := parseUUID(req.ID); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", "id must be uuid")
				return
			}
		}

		metadata := req.Metadata
		if metadata == nil {
			metadata = map[string]interface{}{}
		}
		// These keys are server-controlled and always override client input.
		metadata["collection"] = req.Collection
		metadata["permissions"] = map[string]interface{}{
			"users": permissions,
		}
		metaBytes, err := json.Marshal(metadata)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "metadata must be valid object")
			return
		}
		metaJSON := string(metaBytes)

		if req.AutoChunk || len([]rune(req.Content)) > req.ChunkSize {
			chunks := splitIntoChunks(req.Content, req.ChunkSize, req.Overlap)

			// Embed every chunk before writing anything, so an embedding failure
			// part-way through cannot leave a partially stored document.
			embeddings := make([][]float64, len(chunks))
			for i, chunk := range chunks {
				emb, err := createEmbeddingWithDim(ctx, chunk, dim)
				if err != nil {
					writeError(w, http.StatusBadRequest, "embedding_failed", err.Error())
					return
				}
				embeddings[i] = emb
			}

			// Write all chunk rows atomically: if any insert or the commit
			// fails, nothing from this request is persisted.
			tx, err := db.Begin(ctx)
			if err != nil {
				log.Printf("kb.upsert failed: %v", err)
				writeError(w, http.StatusInternalServerError, "upsert_failed", fmt.Sprintf("failed to upsert document: %v", err))
				return
			}
			// After a successful Commit, Rollback only returns ErrTxClosed, which is ignored.
			defer func() { _ = tx.Rollback(ctx) }()

			resp := KbUpsertResponse{DocID: req.DocID, Updated: false}
			for i, chunk := range chunks {
				id := generateID()
				var updated bool
				if err := tx.QueryRow(ctx, upsertSQL, id, req.DocID, i, req.Collection, chunk, metaJSON, vectorLiteral(embeddings[i])).Scan(&updated); err != nil {
					log.Printf("kb.upsert failed: %v", err)
					writeError(w, http.StatusInternalServerError, "upsert_failed", fmt.Sprintf("failed to upsert document: %v", err))
					return
				}
				resp.Updated = resp.Updated || updated
				resp.Chunks = append(resp.Chunks, struct {
					ID         string `json:"id"`
					ChunkIndex int    `json:"chunk_index"`
				}{ID: id, ChunkIndex: i})
			}
			if err := tx.Commit(ctx); err != nil {
				log.Printf("kb.upsert failed: %v", err)
				writeError(w, http.StatusInternalServerError, "upsert_failed", fmt.Sprintf("failed to upsert document: %v", err))
				return
			}
			if len(resp.Chunks) > 0 {
				resp.ID = resp.Chunks[0].ID
			}
			writeJSON(w, http.StatusOK, resp)
			return
		}

		if len(req.Embedding) == 0 {
			emb, err := createEmbeddingWithDim(ctx, req.Content, dim)
			if err != nil {
				writeError(w, http.StatusBadRequest, "embedding_failed", err.Error())
				return
			}
			req.Embedding = emb
		}
		if len(req.Embedding) != dim {
			writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("embedding must be length %d", dim))
			return
		}
		var updated bool
		if err := db.QueryRow(ctx, upsertSQL, req.ID, req.DocID, req.ChunkIndex, req.Collection, req.Content, metaJSON, vectorLiteral(req.Embedding)).Scan(&updated); err != nil {
			log.Printf("kb.upsert failed: %v", err)
			writeError(w, http.StatusInternalServerError, "upsert_failed", fmt.Sprintf("failed to upsert document: %v", err))
			return
		}
		writeJSON(w, http.StatusOK, KbUpsertResponse{ID: req.ID, DocID: req.DocID, Updated: updated})
	}
}
// handleKbSearch returns the handler for vector similarity search over
// kb_doc_chunks.
//
// Behavior:
//   - Exactly one of collection or collections must be given.
//   - The query is embedded with createEmbeddingWithDim unless query_embedding
//     is supplied; its length must equal embeddingDim().
//   - Results are always restricted by a "contains" filter on
//     metadata.permissions.users with the caller's permissions, ANDed with any
//     client filter. Client filters that mention metadata.permissions are
//     audited and rejected.
//   - Rows are ordered by pgvector's <=> distance to the query vector, and the
//     distance is converted with distanceToSimilarity; NaN/Inf distances score 0.
//   - top_k: 0 selects 5, otherwise it must be 1-50.
func handleKbSearch(db *pgxpool.Pool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if db == nil {
			writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
			return
		}
		logConnectionInfo(db)
		ctx := r.Context()
		// Authorize before doing any work, so forbidden callers cannot trigger
		// (potentially billed) embedding calls below.
		permissions, ok := ctx.Value(ctxPermissionsUsers).([]string)
		if !ok || len(permissions) == 0 {
			writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
			return
		}
		var req KbSearchRequest
		if err := decodeJSON(r, &req); err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}
		if strings.TrimSpace(req.Query) == "" {
			writeError(w, http.StatusBadRequest, "invalid_request", "query is required")
			return
		}
		if req.TopK == 0 {
			req.TopK = 5
		}
		if req.TopK < 1 || req.TopK > 50 {
			writeError(w, http.StatusBadRequest, "invalid_request", "top_k must be between 1 and 50")
			return
		}
		single := strings.TrimSpace(req.Collection) != ""
		switch {
		case !single && len(req.Collections) == 0:
			writeError(w, http.StatusBadRequest, "invalid_request", "collection or collections is required")
			return
		case single && len(req.Collections) > 0:
			writeError(w, http.StatusBadRequest, "invalid_request", "collection and collections are mutually exclusive")
			return
		}
		if single {
			if err := validateCollection(req.Collection); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", "collection is invalid")
				return
			}
		}
		for _, name := range req.Collections {
			if err := validateCollection(name); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", "collections contains invalid value")
				return
			}
		}

		dim := embeddingDim()
		if len(req.QueryEmbedding) == 0 {
			emb, err := createEmbeddingWithDim(ctx, req.Query, dim)
			if err != nil {
				writeError(w, http.StatusBadRequest, "embedding_failed", err.Error())
				return
			}
			req.QueryEmbedding = emb
		}
		if len(req.QueryEmbedding) != dim {
			writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("query_embedding must be length %d", dim))
			return
		}

		permFilter := map[string]interface{}{
			"contains": map[string]interface{}{
				"metadata.permissions.users": permissions,
			},
		}
		mergedFilter := permFilter
		if req.Filter != nil {
			if filterHasMetadataPermissions(req.Filter) {
				apiKeyID, _ := ctx.Value(ctxAPIKeyID).(string)
				provided := map[string]interface{}{"filter": req.Filter}
				if err := recordPermissionsAudit(ctx, db, apiKeyID, permissions, r, provided, provided); err != nil {
					writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
					return
				}
				writeError(w, http.StatusBadRequest, "invalid_request", "metadata.permissions filter is not allowed")
				return
			}
			mergedFilter = map[string]interface{}{
				"and": []interface{}{req.Filter, permFilter},
			}
		}

		// $1 is the query vector and $2 the collection(s); filter placeholders
		// start at $3 and the LIMIT takes the next free index.
		whereSQL, filterParams, limitIdx, err := buildFilterSQL(mergedFilter, 3)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}
		collectionSQL := "collection = ANY($2)"
		var collectionParam interface{} = req.Collections
		if single {
			collectionSQL = "collection = $2"
			collectionParam = req.Collection
		}
		params := append([]interface{}{vectorLiteral(req.QueryEmbedding), collectionParam}, filterParams...)
		params = append(params, req.TopK)
		query := fmt.Sprintf(`
SELECT id::text, collection, content, metadata, (embedding <=> $1) AS score
FROM kb_doc_chunks
WHERE %s AND %s
ORDER BY embedding <=> $1
LIMIT $%d
`, collectionSQL, whereSQL, limitIdx)

		rows, err := db.Query(ctx, query, params...)
		if err != nil {
			log.Printf("kb.search query failed: %v", err)
			writeError(w, http.StatusInternalServerError, "search_failed", fmt.Sprintf("failed to execute search: %v", err))
			return
		}
		defer rows.Close()
		results := make([]KbSearchResult, 0)
		for rows.Next() {
			var res KbSearchResult
			var metaBytes []byte
			if err := rows.Scan(&res.ID, &res.Collection, &res.Content, &metaBytes, &res.Score); err != nil {
				log.Printf("kb.search scan failed: %v", err)
				writeError(w, http.StatusInternalServerError, "search_failed", fmt.Sprintf("failed to read search results: %v", err))
				return
			}
			// The scanned value is a distance; map it to a similarity score.
			if math.IsNaN(res.Score) || math.IsInf(res.Score, 0) {
				res.Score = 0
			} else {
				res.Score = distanceToSimilarity(res.Score)
			}
			if len(metaBytes) > 0 {
				if err := json.Unmarshal(metaBytes, &res.Metadata); err != nil {
					log.Printf("kb.search metadata unmarshal failed: %v", err)
					writeError(w, http.StatusInternalServerError, "search_failed", fmt.Sprintf("failed to parse search results: %v", err))
					return
				}
			}
			results = append(results, res)
		}
		if err := rows.Err(); err != nil {
			log.Printf("kb.search rows failed: %v", err)
			writeError(w, http.StatusInternalServerError, "search_failed", "failed to read search results")
			return
		}
		writeJSON(w, http.StatusOK, KbSearchResponse{Results: results})
	}
}
// handleKbDelete returns the handler that deletes kb_doc_chunks rows, or only
// counts them when dry_run is set.
//
// At least one of id, doc_id or filter must be given; doc_id alone selects
// every chunk of one document. collection may narrow the match further. All
// given criteria are ANDed with a mandatory "contains" filter on
// metadata.permissions.users, so a caller can only affect rows visible to its
// permissions_users. Client filters that mention metadata.permissions are
// audited and rejected. Every delete and dry run is recorded via
// recordDeleteAudit.
//
// NOTE(review): the DELETE and its audit insert are separate statements. If the
// audit write fails, the rows are already gone but the client receives 500.
func handleKbDelete(db *pgxpool.Pool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if db == nil {
			writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
			return
		}
		ctx := r.Context()
		permissions, ok := ctx.Value(ctxPermissionsUsers).([]string)
		if !ok || len(permissions) == 0 {
			writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
			return
		}
		apiKeyID, _ := ctx.Value(ctxAPIKeyID).(string)
		var req KbDeleteRequest
		if err := decodeJSON(r, &req); err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}
		hasID := strings.TrimSpace(req.ID) != ""
		hasDocID := strings.TrimSpace(req.DocID) != ""
		hasCollection := strings.TrimSpace(req.Collection) != ""
		// doc_id is a complete target on its own; it used to be validated but
		// a doc_id-only request was still rejected.
		if !hasID && !hasDocID && req.Filter == nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "id, doc_id or filter is required")
			return
		}
		if hasID {
			if _, err := parseUUID(req.ID); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", "id must be uuid")
				return
			}
		}
		if hasDocID {
			if _, err := parseUUID(req.DocID); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", "doc_id must be uuid")
				return
			}
		}
		if hasCollection {
			if err := validateCollection(req.Collection); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", "collection is invalid")
				return
			}
		}

		permFilter := map[string]interface{}{
			"contains": map[string]interface{}{
				"metadata.permissions.users": permissions,
			},
		}
		mergedFilter := permFilter
		if req.Filter != nil {
			if filterHasMetadataPermissions(req.Filter) {
				provided := map[string]interface{}{"filter": req.Filter}
				if err := recordPermissionsAudit(ctx, db, apiKeyID, permissions, r, provided, provided); err != nil {
					writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
					return
				}
				writeError(w, http.StatusBadRequest, "invalid_request", "metadata.permissions filter is not allowed")
				return
			}
			mergedFilter = map[string]interface{}{
				"and": []interface{}{req.Filter, permFilter},
			}
		}

		// Equality clauses take placeholders $1..$k; the filter continues at $k+1.
		clauses := make([]string, 0, 4)
		params := make([]interface{}, 0, 4)
		addClause := func(column string, value interface{}) {
			params = append(params, value)
			clauses = append(clauses, fmt.Sprintf("%s = $%d", column, len(params)))
		}
		if hasID {
			addClause("id", req.ID)
		}
		if hasDocID {
			addClause("doc_id", req.DocID)
		}
		if hasCollection {
			addClause("collection", req.Collection)
		}
		whereSQL, filterParams, _, err := buildFilterSQL(mergedFilter, len(params)+1)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}
		clauses = append(clauses, whereSQL)
		params = append(params, filterParams...)
		where := strings.Join(clauses, " AND ")

		if req.DryRun {
			var count int64
			if err := db.QueryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM kb_doc_chunks WHERE %s", where), params...).Scan(&count); err != nil {
				log.Printf("kb.delete dry-run failed: %v", err)
				writeError(w, http.StatusInternalServerError, "delete_failed", fmt.Sprintf("failed to count: %v", err))
				return
			}
			if err := recordDeleteAudit(ctx, db, apiKeyID, permissions, req, count); err != nil {
				writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
				return
			}
			writeJSON(w, http.StatusOK, KbDeleteResponse{Deleted: count})
			return
		}

		ct, err := db.Exec(ctx, fmt.Sprintf("DELETE FROM kb_doc_chunks WHERE %s", where), params...)
		if err != nil {
			log.Printf("kb.delete failed: %v", err)
			writeError(w, http.StatusInternalServerError, "delete_failed", fmt.Sprintf("failed to delete: %v", err))
			return
		}
		deleted := ct.RowsAffected()
		if err := recordDeleteAudit(ctx, db, apiKeyID, permissions, req, deleted); err != nil {
			writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
			return
		}
		writeJSON(w, http.StatusOK, KbDeleteResponse{Deleted: deleted})
	}
}
// recordDeleteAudit inserts one kb_delete_audit row for a delete or dry-run
// request. The row records:
//   - the calling key and its permissions;
//   - the dry_run flag and the number of rows deleted (or counted);
//   - the request's targets.
//
// Blank id / doc_id are stored as NULL. The filter is stored as its JSON text
// (NULL when absent). collection is stored exactly as given, including "".
func recordDeleteAudit(ctx context.Context, db *pgxpool.Pool, apiKeyID string, permissions []string, req KbDeleteRequest, deleted int64) error {
	var (
		filterJSON any
		idParam    any
		docIDParam any
	)
	if req.Filter != nil {
		encoded, err := json.Marshal(req.Filter)
		if err != nil {
			return fmt.Errorf("invalid filter for audit: %w", err)
		}
		filterJSON = string(encoded)
	}
	if strings.TrimSpace(req.ID) != "" {
		idParam = req.ID
	}
	if strings.TrimSpace(req.DocID) != "" {
		docIDParam = req.DocID
	}

	const insertAudit = `
INSERT INTO kb_delete_audit (
id, api_key_id, permissions_users, dry_run, deleted_count, target_id, target_doc_id, collection, filter
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
	if _, err := db.Exec(ctx, insertAudit, generateID(), apiKeyID, permissions, req.DryRun, deleted, idParam, docIDParam, req.Collection, filterJSON); err != nil {
		return fmt.Errorf("failed to write delete audit: %w", err)
	}
	return nil
}
func recordPermissionsAudit(ctx context.Context, db *pgxpool.Pool, apiKeyID string, permissions []string, r *http.Request, providedPermissions any, providedMetadata any) error {
if strings.TrimSpace(apiKeyID) == "" {
return errors.New("api_key_id is missing for audit")
}
if len(permissions) == 0 {
return errors.New("permissions.users is missing for audit")
}
if providedPermissions == nil {
return errors.New("provided_permissions is missing for audit")
}
permsJSON, err := json.Marshal(providedPermissions)
if err != nil {
return fmt.Errorf("failed to marshal provided permissions: %w", err)
}
if providedMetadata == nil {
return errors.New("provided_metadata is missing for audit")
}
metadataJSON, err := json.Marshal(providedMetadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
remoteAddr := r.RemoteAddr
userAgent := r.UserAgent()
endpoint := r.Method + " " + r.URL.Path
log.Printf("permissions audit: api_key_id=%s endpoint=%s remote=%s", apiKeyID, endpoint, remoteAddr)
_, err = db.Exec(ctx, `
INSERT INTO kb_permissions_audit (
id, api_key_id, permissions_users, requested_at, endpoint, remote_addr, user_agent, provided_permissions, provided_metadata
) VALUES ($1, $2, $3, now(), $4, $5, $6, $7::jsonb, $8::jsonb)
`, generateID(), apiKeyID, permissions, endpoint, remoteAddr, userAgent, string(permsJSON), string(metadataJSON))
if err != nil {
return fmt.Errorf("failed to write permissions audit: %w", err)
}
return nil
}
// handleCollections returns the GET handler that lists collections with their
// chunk counts. Only rows matching the caller's permissions are counted, using
// the same "contains" filter on metadata.permissions.users as search and
// delete. Other methods get 405 with an Allow: GET header and a JSON error body.
func handleCollections(db *pgxpool.Pool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if db == nil {
			writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
			return
		}
		if r.Method != http.MethodGet {
			// A 405 must list the allowed methods (RFC 9110). Also send the JSON
			// error body that every other failure path in this API uses.
			w.Header().Set("Allow", http.MethodGet)
			writeError(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed")
			return
		}
		permissions, ok := r.Context().Value(ctxPermissionsUsers).([]string)
		if !ok || len(permissions) == 0 {
			writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
			return
		}
		permFilter := map[string]interface{}{
			"contains": map[string]interface{}{
				"metadata.permissions.users": permissions,
			},
		}
		whereSQL, params, _, err := buildFilterSQL(permFilter, 1)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}
		query := fmt.Sprintf(`
SELECT collection, COUNT(*) AS cnt
FROM kb_doc_chunks
WHERE %s
GROUP BY collection
ORDER BY collection ASC
`, whereSQL)
		rows, err := db.Query(r.Context(), query, params...)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "collections_failed", fmt.Sprintf("failed to list collections: %v", err))
			return
		}
		defer rows.Close()
		items := make([]CollectionItem, 0)
		for rows.Next() {
			var item CollectionItem
			if err := rows.Scan(&item.Name, &item.Count); err != nil {
				writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
				return
			}
			items = append(items, item)
		}
		if err := rows.Err(); err != nil {
			writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
			return
		}
		writeJSON(w, http.StatusOK, CollectionsResponse{Items: items})
	}
}
// handleAdminCollections serves GET /admin/collections: every collection in
// kb_doc_chunks with its chunk count, ordered by name, with no permission
// filtering (the route is wrapped in admin authentication by main).
func handleAdminCollections(db *pgxpool.Pool) http.HandlerFunc {
	const listSQL = `
SELECT collection, COUNT(*) AS cnt
FROM kb_doc_chunks
GROUP BY collection
ORDER BY collection ASC
`
	return func(w http.ResponseWriter, r *http.Request) {
		switch {
		case db == nil:
			writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
			return
		case r.Method != http.MethodGet:
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		rows, queryErr := db.Query(r.Context(), listSQL)
		if queryErr != nil {
			writeError(w, http.StatusInternalServerError, "collections_failed", fmt.Sprintf("failed to list collections: %v", queryErr))
			return
		}
		defer rows.Close()
		// Start from an empty (non-nil) slice so no rows encodes as [].
		result := make([]CollectionItem, 0)
		for rows.Next() {
			var entry CollectionItem
			if scanErr := rows.Scan(&entry.Name, &entry.Count); scanErr != nil {
				writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
				return
			}
			result = append(result, entry)
		}
		if iterErr := rows.Err(); iterErr != nil {
			writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
			return
		}
		writeJSON(w, http.StatusOK, CollectionsResponse{Items: result})
	}
}
// handleAdminPermissionsAudit serves GET /admin/audit/permissions, returning rows
// of kb_permissions_audit newest first.
//
// Optional query parameters:
//   - since, until: RFC3339 bounds on requested_at (inclusive)
//   - api_key_id: exact match on the key id (surrounding whitespace ignored)
//   - limit: positive integer, default 100, clamped to 1000
//
// The provided_permissions / provided_metadata JSON columns are decoded into maps.
func handleAdminPermissionsAudit(db *pgxpool.Pool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if db == nil {
			writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
			return
		}
		if r.Method != http.MethodGet {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		qs := r.URL.Query()
		since, err := parseTimeParam(qs.Get("since"))
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "since must be RFC3339")
			return
		}
		until, err := parseTimeParam(qs.Get("until"))
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", "until must be RFC3339")
			return
		}
		limit, err := parseLimit(qs.Get("limit"), 100, 1000)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
			return
		}

		// Each condition's placeholder number is the position of its argument.
		var conds []string
		var args []interface{}
		addCond := func(format string, arg interface{}) {
			args = append(args, arg)
			conds = append(conds, fmt.Sprintf(format, len(args)))
		}
		if since != nil {
			addCond("requested_at >= $%d", *since)
		}
		if until != nil {
			addCond("requested_at <= $%d", *until)
		}
		if keyID := strings.TrimSpace(qs.Get("api_key_id")); keyID != "" {
			addCond("api_key_id = $%d", keyID)
		}
		where := "TRUE"
		if len(conds) > 0 {
			where = strings.Join(conds, " AND ")
		}
		args = append(args, limit)
		sqlText := fmt.Sprintf(`
SELECT id::text, api_key_id, permissions_users, requested_at, endpoint, remote_addr, user_agent, provided_permissions, provided_metadata
FROM kb_permissions_audit
WHERE %s
ORDER BY requested_at DESC
LIMIT $%d
`, where, len(args))

		rows, err := db.Query(r.Context(), sqlText, args...)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "audit_failed", fmt.Sprintf("failed to read audit logs: %v", err))
			return
		}
		defer rows.Close()
		entries := make([]PermissionsAuditItem, 0)
		for rows.Next() {
			var entry PermissionsAuditItem
			var permsRaw, metaRaw []byte
			if scanErr := rows.Scan(&entry.ID, &entry.APIKeyID, &entry.PermissionsUsers, &entry.RequestedAt, &entry.Endpoint, &entry.RemoteAddr, &entry.UserAgent, &permsRaw, &metaRaw); scanErr != nil {
				writeError(w, http.StatusInternalServerError, "audit_failed", "failed to read audit logs")
				return
			}
			if len(permsRaw) > 0 && json.Unmarshal(permsRaw, &entry.ProvidedPermissions) != nil {
				writeError(w, http.StatusInternalServerError, "audit_failed", "failed to parse audit logs")
				return
			}
			if len(metaRaw) > 0 && json.Unmarshal(metaRaw, &entry.ProvidedMetadata) != nil {
				writeError(w, http.StatusInternalServerError, "audit_failed", "failed to parse audit logs")
				return
			}
			entries = append(entries, entry)
		}
		if iterErr := rows.Err(); iterErr != nil {
			writeError(w, http.StatusInternalServerError, "audit_failed", "failed to read audit logs")
			return
		}
		writeJSON(w, http.StatusOK, PermissionsAuditResponse{Items: entries})
	}
}
// handleAdminApiKeys serves the admin API-key routes:
//
//	GET   /admin/api-keys              list keys
//	POST  /admin/api-keys              create a key (AdminApiKeyCreateRequest); the raw key is returned once
//	PATCH /admin/api-keys/{id}         update label / permissions.users (AdminApiKeyUpdateRequest)
//	POST  /admin/api-keys/{id}/revoke  revoke a key
//
// Any other sub-path (including an empty id or extra segments) gets 404. A
// known route hit with the wrong method gets 405 with an Allow header.
// Store errors matching ErrNotFound map to 404; other store errors are
// returned as 400 with the error text.
func handleAdminApiKeys(store APIKeyStore) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		path := strings.TrimPrefix(r.URL.Path, "/admin/api-keys")
		path = strings.TrimPrefix(path, "/")
		if path == "" {
			switch r.Method {
			case http.MethodGet:
				items, err := store.List(r.Context())
				if err != nil {
					writeError(w, http.StatusInternalServerError, "list_failed", "failed to list api keys")
					return
				}
				writeJSON(w, http.StatusOK, AdminApiKeyListResponse{Items: items})
			case http.MethodPost:
				var req AdminApiKeyCreateRequest
				if err := decodeJSON(r, &req); err != nil {
					writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
					return
				}
				rawKey, key, err := store.Create(r.Context(), req.Label, req.PermissionsUsers)
				if err != nil {
					writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
					return
				}
				writeJSON(w, http.StatusOK, AdminApiKeyCreateResponse{
					APIKey: rawKey,
					Key:    key,
				})
			default:
				w.Header().Set("Allow", http.MethodGet+", "+http.MethodPost)
				w.WriteHeader(http.StatusMethodNotAllowed)
			}
			return
		}

		// Exactly "{id}" or "{id}/revoke"; anything else is not a route.
		id, action, hasAction := strings.Cut(path, "/")
		switch {
		case id == "":
			w.WriteHeader(http.StatusNotFound)
		case !hasAction:
			if r.Method != http.MethodPatch {
				w.Header().Set("Allow", http.MethodPatch)
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			var req AdminApiKeyUpdateRequest
			if err := decodeJSON(r, &req); err != nil {
				writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
				return
			}
			updated, err := store.Update(r.Context(), id, req.Label, req.PermissionsUsers)
			if err != nil {
				if errors.Is(err, ErrNotFound) {
					writeError(w, http.StatusNotFound, "not_found", "api key not found")
					return
				}
				writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
				return
			}
			writeJSON(w, http.StatusOK, updated)
		case action == "revoke":
			if r.Method != http.MethodPost {
				w.Header().Set("Allow", http.MethodPost)
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			key, err := store.Revoke(r.Context(), id)
			if err != nil {
				if errors.Is(err, ErrNotFound) {
					writeError(w, http.StatusNotFound, "not_found", "api key not found")
					return
				}
				writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
				return
			}
			writeJSON(w, http.StatusOK, key)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}
}
// handleAdminUI serves the embedded admin console page (adminHTML) for GET
// requests; every other method receives a bare 405.
func handleAdminUI(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		// A failed write means the client went away; there is nothing to report to it.
		_, _ = w.Write([]byte(adminHTML))
	default:
		w.WriteHeader(http.StatusMethodNotAllowed)
	}
}
// writeJSON marshals v and sends it with the given status code, setting
// Content-Type and Content-Length. If v cannot be marshalled, a fixed
// encode_failed error body is sent with status 500 instead. Write failures
// (typically a disconnected client) are only logged.
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	body, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"failed to encode response","code":"encode_failed"}`))
		return
	}
	hdr.Set("Content-Length", fmt.Sprint(len(body)))
	w.WriteHeader(status)
	if _, writeErr := w.Write(body); writeErr != nil {
		log.Printf("writeJSON failed: %v", writeErr)
	}
}
func writeError(w http.ResponseWriter, status int, code, msg string) {
writeJSON(w, status, ErrorResponse{Error: msg, Code: code})
}
// decodeJSON decodes the request body into v. It rejects fields that v does not
// declare and rejects bodies that carry anything after the first JSON value.
// Callers send err.Error() back in a 400 response.
func decodeJSON(r *http.Request, v interface{}) error {
	dec := json.NewDecoder(r.Body)
	dec.DisallowUnknownFields()
	if err := dec.Decode(v); err != nil {
		return err
	}
	// Decode stops after the first value, so a body like {"label":"a"}{"label":"b"}
	// would otherwise be accepted with the rest silently ignored. More reports any
	// trailing non-whitespace (except a stray ']' or '}').
	if dec.More() {
		return errors.New("request body must contain a single JSON value")
	}
	return nil
}
// adminHTML is the single-page admin console that handleAdminUI serves at /admin.
// It calls the /admin/api-keys and /admin/collections endpoints with the
// X-ADMIN-API-KEY header. All values coming from the API are inserted through
// textContent / DOM nodes, never innerHTML, so key labels and user names cannot
// inject markup or script. Note: this is a Go raw string, so it must not
// contain backticks (no JS template literals).
const adminHTML = `<!doctype html>
<html lang="ja">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>pgvecter API Admin</title>
<style>
body { font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, sans-serif; margin: 24px; }
h1 { margin-bottom: 12px; }
section { margin-bottom: 24px; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
th { background: #f5f5f5; }
input, textarea, button { font-size: 14px; }
textarea { width: 100%; height: 60px; }
.row { display: flex; gap: 12px; }
.row > div { flex: 1; }
.muted { color: #666; font-size: 12px; }
</style>
</head>
<body>
<h1>pgvecter API Admin</h1>
<p class="muted">X-ADMIN-API-KEY を使って管理APIを呼び出します(admin_key クエリは localhost のみ)。</p>
<div id="status" class="muted"></div>
<section id="connectSection">
<h2>接続</h2>
<div class="row">
<div>
<label>API Base URL</label><br />
<input id="baseUrl" value="http://localhost:8080" style="width:100%" />
</div>
<div>
<label>Admin API Key</label><br />
<input id="adminKey" type="password" style="width:100%" />
</div>
</div>
<button onclick="loadKeys()">一覧を取得</button>
</section>
<section>
<h2>発行</h2>
<div class="row">
<div>
<label>Label</label><br />
<input id="newLabel" style="width:100%" />
</div>
<div>
<label>permissions.users (改行区切り)</label><br />
<textarea id="newUsers"></textarea>
</div>
</div>
<button onclick="createKey()">発行</button>
<pre id="issuedKey"></pre>
<button onclick="copyIssuedKey()">コピー</button>
</section>
<section>
<h2>APIキー一覧</h2>
<table id="keysTable">
<thead>
<tr>
<th>ID</th>
<th>Label</th>
<th>Users</th>
<th>Status</th>
<th>Created</th>
<th>Last Used</th>
<th>Actions</th>
</tr>
</thead>
<tbody></tbody>
</table>
</section>
<section>
<h2>コレクション一覧</h2>
<table id="collectionsTable">
<thead>
<tr>
<th>Name</th>
<th>Count</th>
</tr>
</thead>
<tbody></tbody>
</table>
</section>
<section id="editSection">
<h2>編集</h2>
<div class="row">
<div>
<label>Key ID</label><br />
<input id="editId" style="width:100%" />
</div>
<div>
<label>Label</label><br />
<input id="editLabel" style="width:100%" />
</div>
<div>
<label>permissions.users (改行区切り)</label><br />
<textarea id="editUsers"></textarea>
</div>
</div>
<button onclick="updateKey()">更新</button>
</section>
<script>
// Base URL of the API, taken from the connection form.
function getBaseUrl() { return document.getElementById('baseUrl').value.trim(); }
// Admin key from the form; on localhost only, fall back to the admin_key query
// parameter (mirrors the server-side rule in withAdminUIAuth).
function getAdminKey() {
  const input = document.getElementById('adminKey').value.trim();
  if (input) { return input; }
  const host = window.location.hostname;
  if (host === 'localhost' || host === '127.0.0.1' || host === '::1') {
    const params = new URLSearchParams(window.location.search);
    return params.get('admin_key') || '';
  }
  return '';
}
function headers() {
  return { 'Content-Type': 'application/json', 'X-ADMIN-API-KEY': getAdminKey() };
}
// Splits textarea text into trimmed, non-empty lines.
function parseUsers(text) {
  return text.split('\n').map(s => s.trim()).filter(Boolean);
}
// Table cell whose content is set as plain text so API data is never parsed as HTML.
function textCell(value) {
  const td = document.createElement('td');
  td.textContent = value == null ? '' : String(value);
  return td;
}
function actionButton(label, onClick) {
  const btn = document.createElement('button');
  btn.textContent = label;
  btn.addEventListener('click', onClick);
  return btn;
}
async function loadKeys() {
  const res = await fetch(getBaseUrl() + '/admin/api-keys', { headers: headers() });
  if (!res.ok) { setStatus('一覧取得に失敗しました', 'error'); return; }
  const data = await res.json();
  const tbody = document.querySelector('#keysTable tbody');
  tbody.innerHTML = '';
  // Guard against a null items array.
  for (const item of data.items || []) {
    const tr = document.createElement('tr');
    tr.appendChild(textCell(item.id));
    tr.appendChild(textCell(item.label));
    tr.appendChild(textCell((item.permissions_users || []).join(', ')));
    tr.appendChild(textCell(item.status));
    tr.appendChild(textCell(item.created_at));
    tr.appendChild(textCell(item.last_used_at || ''));
    const actions = document.createElement('td');
    actions.appendChild(actionButton('失効', () => revokeKey(item.id)));
    actions.appendChild(document.createTextNode(' '));
    actions.appendChild(actionButton('編集', () => fillEditForm(item)));
    tr.appendChild(actions);
    tbody.appendChild(tr);
  }
}
async function loadCollections() {
  const res = await fetch(getBaseUrl() + '/admin/collections', { headers: headers() });
  if (!res.ok) { setStatus('コレクション取得に失敗しました', 'error'); return; }
  const data = await res.json();
  const tbody = document.querySelector('#collectionsTable tbody');
  tbody.innerHTML = '';
  for (const item of data.items || []) {
    const tr = document.createElement('tr');
    tr.appendChild(textCell(item.name));
    tr.appendChild(textCell(item.count));
    tbody.appendChild(tr);
  }
}
async function createKey() {
  const label = document.getElementById('newLabel').value.trim();
  const users = parseUsers(document.getElementById('newUsers').value);
  const res = await fetch(getBaseUrl() + '/admin/api-keys', {
    method: 'POST',
    headers: headers(),
    body: JSON.stringify({ label: label, permissions_users: users })
  });
  if (!res.ok) { setStatus('発行に失敗しました', 'error'); return; }
  const data = await res.json();
  document.getElementById('issuedKey').textContent = 'API Key: ' + data.api_key;
  setStatus('発行しました', 'ok');
  await loadKeys();
  await loadCollections();
}
async function revokeKey(id) {
  const res = await fetch(getBaseUrl() + '/admin/api-keys/' + encodeURIComponent(id) + '/revoke', {
    method: 'POST',
    headers: headers()
  });
  if (!res.ok) { setStatus('失効に失敗しました', 'error'); return; }
  setStatus('失効しました', 'ok');
  await loadKeys();
  await loadCollections();
}
// Copies a key into the edit form and scrolls to it.
function fillEditForm(item) {
  document.getElementById('editId').value = item.id;
  document.getElementById('editLabel').value = item.label || '';
  document.getElementById('editUsers').value = (item.permissions_users || []).join('\n');
  const section = document.getElementById('editSection');
  if (section && section.scrollIntoView) {
    section.scrollIntoView({ behavior: 'smooth', block: 'start' });
  }
}
async function updateKey() {
  const id = document.getElementById('editId').value.trim();
  const label = document.getElementById('editLabel').value.trim();
  const users = parseUsers(document.getElementById('editUsers').value);
  if (!id) { setStatus('Key IDが必要です', 'error'); return; }
  const res = await fetch(getBaseUrl() + '/admin/api-keys/' + encodeURIComponent(id), {
    method: 'PATCH',
    headers: headers(),
    body: JSON.stringify({ label: label, permissions_users: users })
  });
  if (!res.ok) { setStatus('更新に失敗しました', 'error'); return; }
  setStatus('更新しました', 'ok');
  await loadKeys();
  await loadCollections();
}
// Shows a status message for 3 seconds; type 'error' renders it in red.
function setStatus(msg, type) {
  const el = document.getElementById('status');
  if (!el) { return; }
  el.textContent = msg;
  el.style.color = type === 'error' ? '#b00020' : '#0b6e4f';
  clearTimeout(window.__statusTimer);
  window.__statusTimer = setTimeout(() => { el.textContent = ''; }, 3000);
}
async function copyIssuedKey() {
  const text = document.getElementById('issuedKey').textContent.replace('API Key: ', '').trim();
  if (!text) { setStatus('コピー対象がありません', 'error'); return; }
  try {
    await navigator.clipboard.writeText(text);
    setStatus('APIキーをコピーしました', 'ok');
  } catch (e) {
    setStatus('コピーに失敗しました', 'error');
  }
}
window.addEventListener('DOMContentLoaded', () => {
  loadKeys();
  loadCollections();
});
</script>
</body>
</html>`
// withAdminUIAuth guards the admin console page.
//
// The admin key is read from the X-ADMIN-API-KEY header. Only requests whose peer
// address is loopback (isLocalhostRequest) may pass it as the admin_key query
// parameter instead; other requests that try to are rejected with 400.
// Responses: 503 if ADMIN_API_KEY is not configured, 401 for a missing or
// wrong key.
func withAdminUIAuth(adminKey string, next http.Handler) http.Handler {
	// Compare fixed-size SHA-256 digests with a branch-free loop so that response
	// timing does not reveal how much of a guessed key matches the real one
	// (a plain != on strings returns at the first differing byte).
	want := sha256.Sum256([]byte(adminKey))
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if adminKey == "" {
			writeError(w, http.StatusServiceUnavailable, "admin_key_not_configured", "ADMIN_API_KEY is not set")
			return
		}
		key := r.Header.Get("X-ADMIN-API-KEY")
		if key == "" {
			queryKey := r.URL.Query().Get("admin_key")
			if isLocalhostRequest(r) {
				key = queryKey
			} else if queryKey != "" {
				writeError(w, http.StatusBadRequest, "invalid_request", "admin_key query parameter is only allowed from localhost")
				return
			}
		}
		got := sha256.Sum256([]byte(key))
		var diff byte
		for i := range got {
			diff |= got[i] ^ want[i]
		}
		if key == "" || diff != 0 {
			writeError(w, http.StatusUnauthorized, "unauthorized", "invalid admin api key")
			return
		}
		next.ServeHTTP(w, r)
	})
}
// buildFilterSQL compiles a JSON filter tree into a SQL boolean expression over
// the metadata jsonb column, numbering bind placeholders from $startIdx.
//
// A node is {"and": [node, ...]}, {"or": [node, ...]}, or one leaf operator
// {"eq"|"in"|"contains"|"exists"|"lt"|"lte"|"gt"|"gte": {"metadata.<path>": value}}
// (see buildSimpleOp). A nil or empty filter compiles to "TRUE". If a node has
// several operator keys, only the first match in the order above is used and
// the others are ignored.
//
// It returns the SQL fragment, its bind values in placeholder order, the next
// unused placeholder index, and an error describing a malformed filter.
func buildFilterSQL(filter map[string]interface{}, startIdx int) (string, []interface{}, int, error) {
	if len(filter) == 0 {
		return "TRUE", nil, startIdx, nil
	}
	if v, ok := filter["and"]; ok {
		return buildFilterGroupSQL("and", " AND ", v, startIdx)
	}
	if v, ok := filter["or"]; ok {
		return buildFilterGroupSQL("or", " OR ", v, startIdx)
	}
	for _, op := range []string{"eq", "in", "contains", "exists", "lt", "lte", "gt", "gte"} {
		if v, ok := filter[op]; ok {
			return buildSimpleOp(op, v, startIdx)
		}
	}
	return "", nil, startIdx, errors.New("unsupported filter operator")
}

// buildFilterGroupSQL compiles the children of an "and"/"or" node and joins them
// with sep. The whole result is parenthesized so that splicing it into a larger
// expression (e.g. "x AND " + fragment) cannot let a top-level OR escape the
// caller's other conditions.
func buildFilterGroupSQL(op, sep string, v interface{}, startIdx int) (string, []interface{}, int, error) {
	items, ok := v.([]interface{})
	if !ok || len(items) == 0 {
		return "", nil, startIdx, fmt.Errorf("%s must be a non-empty array", op)
	}
	parts := make([]string, 0, len(items))
	var params []interface{}
	idx := startIdx
	for _, item := range items {
		m, ok := item.(map[string]interface{})
		if !ok {
			return "", nil, startIdx, fmt.Errorf("%s items must be objects", op)
		}
		sql, p, next, err := buildFilterSQL(m, idx)
		if err != nil {
			return "", nil, startIdx, err
		}
		parts = append(parts, "("+sql+")")
		params = append(params, p...)
		idx = next
	}
	return "(" + strings.Join(parts, sep) + ")", params, idx, nil
}

// buildSimpleOp compiles one leaf operator. v must be an object with exactly one
// "metadata.<path>" key; the path is validated by parseMetadataPath.
//
//   - eq, lt, lte, gt, gte: compare the text value at path (#>>) with
//     fmt.Sprint(value). The comparison is textual, so numbers order
//     lexicographically (e.g. "10" < "9").
//   - in: the text value at path equals any listed value (each via fmt.Sprint).
//   - contains: a string value does a case-insensitive substring match (ILIKE)
//     in which %, _ and \ are matched literally; an array value ([]interface{}
//     or []string) uses jsonb containment (@>).
//   - exists: the last path segment is a key of the object at the parent path.
func buildSimpleOp(op string, v interface{}, startIdx int) (string, []interface{}, int, error) {
	obj, ok := v.(map[string]interface{})
	if !ok || len(obj) != 1 {
		return "", nil, startIdx, fmt.Errorf("%s must be an object with a single field", op)
	}
	var path string
	var value interface{}
	for k, val := range obj {
		path, value = k, val
	}
	keys, err := parseMetadataPath(path)
	if err != nil {
		return "", nil, startIdx, err
	}
	// parseMetadataPath restricts segments to [A-Za-z0-9_], so interpolating them
	// into the SQL text cannot break out of the quoted path literal.
	pathLiteral := strings.Join(keys, ",")
	textAtPath := fmt.Sprintf("metadata #>> '{%s}'", pathLiteral)
	switch op {
	case "eq", "lt", "lte", "gt", "gte":
		sqlOp := map[string]string{"eq": "=", "lt": "<", "lte": "<=", "gt": ">", "gte": ">="}[op]
		sql := fmt.Sprintf("%s %s $%d", textAtPath, sqlOp, startIdx)
		return sql, []interface{}{fmt.Sprint(value)}, startIdx + 1, nil
	case "in":
		list, ok := value.([]interface{})
		if !ok || len(list) == 0 {
			return "", nil, startIdx, errors.New("in must be a non-empty array")
		}
		values := make([]string, 0, len(list))
		for _, item := range list {
			values = append(values, fmt.Sprint(item))
		}
		sql := fmt.Sprintf("%s = ANY($%d)", textAtPath, startIdx)
		return sql, []interface{}{values}, startIdx + 1, nil
	case "contains":
		containsArray := func(list interface{}, n int) (string, []interface{}, int, error) {
			if n == 0 {
				return "", nil, startIdx, errors.New("contains array must be non-empty")
			}
			raw, marshalErr := json.Marshal(list)
			if marshalErr != nil {
				return "", nil, startIdx, marshalErr
			}
			sql := fmt.Sprintf("metadata #> '{%s}' @> $%d::jsonb", pathLiteral, startIdx)
			return sql, []interface{}{string(raw)}, startIdx + 1, nil
		}
		switch val := value.(type) {
		case string:
			// In LIKE patterns % and _ are wildcards and \ is the default escape
			// character; escape all three so the value matches as a literal substring.
			escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(val)
			sql := fmt.Sprintf("%s ILIKE '%%' || $%d || '%%'", textAtPath, startIdx)
			return sql, []interface{}{escaped}, startIdx + 1, nil
		case []interface{}:
			return containsArray(val, len(val))
		case []string:
			return containsArray(val, len(val))
		default:
			return "", nil, startIdx, errors.New("contains must be string or array")
		}
	case "exists":
		if len(keys) == 1 {
			sql := fmt.Sprintf("metadata ? '%s'", keys[0])
			return sql, nil, startIdx, nil
		}
		parent := strings.Join(keys[:len(keys)-1], ",")
		leaf := keys[len(keys)-1]
		sql := fmt.Sprintf("(metadata #> '{%s}') ? '%s'", parent, leaf)
		return sql, nil, startIdx, nil
	default:
		return "", nil, startIdx, errors.New("unsupported operator")
	}
}

// parseMetadataPath splits a "metadata.a.b" filter path into its segments
// ("a", "b"). Every segment must be non-empty and consist only of ASCII
// letters, digits and underscores.
func parseMetadataPath(path string) ([]string, error) {
	const prefix = "metadata."
	if !strings.HasPrefix(path, prefix) {
		return nil, errors.New("filter path must start with metadata.")
	}
	segments := strings.Split(strings.TrimPrefix(path, prefix), ".")
	for _, seg := range segments {
		if seg == "" || !isSafeIdent(seg) {
			return nil, errors.New("filter path contains invalid characters")
		}
	}
	return segments, nil
}

// isSafeIdent reports whether s contains only ASCII letters, digits and
// underscores. The empty string is considered safe.
func isSafeIdent(s string) bool {
	return strings.IndexFunc(s, func(r rune) bool {
		allowed := (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_'
		return !allowed
	}) < 0
}
// vectorLiteral renders vec in pgvector's text input format, e.g. "[0.5,-1]".
//
// Each component is written as the shortest decimal that round-trips to the
// float32 nearest the component. pgvector stores single-precision components,
// so nothing is lost. By contrast, a fixed "%.8f" turns components below 5e-9
// into 0 and rounds the rest to an arbitrary 8 decimals.
func vectorLiteral(vec []float64) string {
	var b strings.Builder
	b.WriteByte('[')
	for i, v := range vec {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.FormatFloat(v, 'f', -1, 32))
	}
	b.WriteByte(']')
	return b.String()
}
// distanceToSimilarity maps a cosine distance, nominally in [0, 2], onto a
// similarity score in [0, 1] (distance 0 -> 1, distance 2 -> 0), clamping
// anything outside that range.
//
// NaN is mapped to 0. encoding/json refuses to marshal NaN, so a single NaN
// score would otherwise fail the whole JSON response. NOTE(review): pgvector can
// presumably yield NaN cosine distance for zero-norm vectors; confirm for the
// deployed version.
func distanceToSimilarity(distance float64) float64 {
	if math.IsNaN(distance) {
		return 0
	}
	sim := 1 - distance/2
	if sim < 0 {
		return 0
	}
	if sim > 1 {
		return 1
	}
	return sim
}
// Wire types for the OpenAI /v1/embeddings endpoint.
type (
	// embeddingRequest is the JSON request body; EncodingFormat is omitted when empty.
	embeddingRequest struct {
		Input          string `json:"input"`
		Model          string `json:"model"`
		EncodingFormat string `json:"encoding_format,omitempty"`
	}

	// embeddingResponse keeps only the part of the reply this service reads:
	// the embedding vector of each item in "data".
	embeddingResponse struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
		} `json:"data"`
	}
)
func createEmbedding(ctx context.Context, text string) ([]float64, error) {
return createEmbeddingWithDim(ctx, text, embeddingDim())
}
// createEmbeddingWithDim embeds text and checks that the result has exactly dim
// components, for either provider.
//
// The provider comes from EMBEDDING_PROVIDER ("llamacpp" or "openai", case-
// insensitive). When that is unset it defaults to "llamacpp" if LLAMA_CPP_URL is
// set and to "openai" otherwise. The OpenAI client does not know the target
// dimension, so the length is checked here. For example, the default model
// text-embedding-3-small returns 1536 values while EMBEDDING_DIM defaults to 1024.
// The mismatch then surfaces as a clear error instead of a database failure.
func createEmbeddingWithDim(ctx context.Context, text string, dim int) ([]float64, error) {
	provider := strings.ToLower(strings.TrimSpace(os.Getenv("EMBEDDING_PROVIDER")))
	llamaURL := strings.TrimSpace(os.Getenv("LLAMA_CPP_URL"))
	if provider == "" && llamaURL != "" {
		provider = "llamacpp"
	}
	if provider == "" {
		provider = "openai"
	}
	switch provider {
	case "llamacpp":
		return createLlamaCppEmbedding(ctx, text, dim)
	case "openai":
		emb, err := createOpenAIEmbedding(ctx, text)
		if err != nil {
			return nil, err
		}
		if len(emb) != dim {
			return nil, fmt.Errorf("embedding dimension mismatch: got %d, expected %d", len(emb), dim)
		}
		return emb, nil
	default:
		return nil, fmt.Errorf("unsupported EMBEDDING_PROVIDER: %s", provider)
	}
}
// createLlamaCppEmbedding requests an embedding from a llama.cpp server's
// /embedding endpoint. The base URL comes from LLAMA_CPP_URL and defaults to
// http://127.0.0.1:8092. The response is decoded as
// [{"embedding": [[...], ...]}, ...]; the first vector of the first item is
// returned. It must have exactly dim components. Requests time out after 15s.
func createLlamaCppEmbedding(ctx context.Context, text string, dim int) ([]float64, error) {
	base := strings.TrimSpace(os.Getenv("LLAMA_CPP_URL"))
	if base == "" {
		base = "http://127.0.0.1:8092"
	}
	body, err := json.Marshal(map[string]string{"content": text})
	if err != nil {
		return nil, err
	}
	url := strings.TrimRight(base, "/") + "/embedding"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(body)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	httpClient := &http.Client{Timeout: 15 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if code := resp.StatusCode; code < 200 || code > 299 {
		return nil, fmt.Errorf("llama.cpp embedding request failed with status %d", code)
	}

	var decoded []struct {
		Embedding [][]float64 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, err
	}
	if len(decoded) == 0 || len(decoded[0].Embedding) == 0 || len(decoded[0].Embedding[0]) == 0 {
		return nil, errors.New("llama.cpp embedding response is empty")
	}
	vec := decoded[0].Embedding[0]
	if got := len(vec); got != dim {
		return nil, fmt.Errorf("embedding dimension mismatch: got %d, expected %d", got, dim)
	}
	return vec, nil
}
// createOpenAIEmbedding requests a float embedding of text from the OpenAI
// embeddings API and returns the first vector in the reply.
// Requires OPENAI_API_KEY. The model comes from EMBEDDING_MODEL and defaults to
// text-embedding-3-small. The returned length is not checked here.
// Requests time out after 15s.
func createOpenAIEmbedding(ctx context.Context, text string) ([]float64, error) {
	key := strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
	if key == "" {
		return nil, errors.New("OPENAI_API_KEY is not set")
	}
	modelName := strings.TrimSpace(os.Getenv("EMBEDDING_MODEL"))
	if modelName == "" {
		modelName = "text-embedding-3-small"
	}
	body, err := json.Marshal(embeddingRequest{
		Input:          text,
		Model:          modelName,
		EncodingFormat: "float",
	})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/embeddings", strings.NewReader(string(body)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+key)

	httpClient := &http.Client{Timeout: 15 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if code := resp.StatusCode; code < 200 || code > 299 {
		return nil, fmt.Errorf("embedding request failed with status %d", code)
	}

	var decoded embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return nil, err
	}
	if len(decoded.Data) == 0 || len(decoded.Data[0].Embedding) == 0 {
		return nil, errors.New("embedding response is empty")
	}
	return decoded.Data[0].Embedding, nil
}
// embeddingDim returns the expected embedding dimension from EMBEDDING_DIM
// (surrounding whitespace ignored). It falls back to 1024 when the variable is
// unset, not an integer, or not positive.
func embeddingDim() int {
	const fallback = 1024
	n, err := strconv.Atoi(strings.TrimSpace(os.Getenv("EMBEDDING_DIM")))
	if err == nil && n > 0 {
		return n
	}
	return fallback
}
// parseTimeParam parses an optional RFC3339 query value. Surrounding whitespace
// is ignored. An empty value yields (nil, nil), and an unparsable one returns
// the time.Parse error.
func parseTimeParam(raw string) (*time.Time, error) {
	// Parse the same trimmed string that was checked for emptiness; previously
	// " 2024-01-02T03:04:05Z" passed the empty check but failed to parse.
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil, nil
	}
	t, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return nil, err
	}
	return &t, nil
}
// parseLimit parses an optional "limit" query value. Surrounding whitespace is
// ignored. An empty value yields def, values above ceiling are clamped to
// ceiling, and anything that is not a positive integer is an error.
func parseLimit(raw string, def, ceiling int) (int, error) {
	// Parse the trimmed value (previously only the emptiness check was trimmed).
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return def, nil
	}
	n, err := strconv.Atoi(raw)
	if err != nil || n <= 0 {
		return 0, errors.New("limit must be a positive integer")
	}
	if n > ceiling {
		return ceiling, nil
	}
	return n, nil
}
// isLocalhostRequest reports whether the request's direct peer address
// (r.RemoteAddr, with or without a port) is a loopback IP.
// NOTE(review): behind a reverse proxy on the same host every request appears
// to come from loopback; confirm the deployment does not rely on this check there.
func isLocalhostRequest(r *http.Request) bool {
	addr := r.RemoteAddr
	if host, _, splitErr := net.SplitHostPort(addr); splitErr == nil {
		addr = host
	}
	parsed := net.ParseIP(addr)
	return parsed != nil && parsed.IsLoopback()
}
// hasMetadataPermissions reports whether meta has a top-level "permissions" key
// (its value may be anything, including nil). A nil map has no keys.
func hasMetadataPermissions(meta map[string]interface{}) bool {
	// Indexing a nil map is safe in Go and simply reports "absent".
	_, present := meta["permissions"]
	return present
}
// filterHasMetadataPermissions reports whether any key anywhere in the filter
// tree starts with "metadata.permissions" (see valueHasMetadataPermissions).
func filterHasMetadataPermissions(filter map[string]interface{}) bool {
	var tree interface{} = filter
	return valueHasMetadataPermissions(tree)
}
// valueHasMetadataPermissions walks a decoded JSON value and reports whether
// any object key, at any depth, starts with "metadata.permissions". Objects
// (map[string]interface{}) and arrays ([]interface{}) are searched recursively;
// all other values are leaves and never match.
func valueHasMetadataPermissions(value interface{}) bool {
	if obj, isObj := value.(map[string]interface{}); isObj {
		for key, nested := range obj {
			if strings.HasPrefix(key, "metadata.permissions") || valueHasMetadataPermissions(nested) {
				return true
			}
		}
		return false
	}
	if arr, isArr := value.([]interface{}); isArr {
		for _, elem := range arr {
			if valueHasMetadataPermissions(elem) {
				return true
			}
		}
	}
	return false
}
// fetchEmbeddingDim reads the declared type of public.kb_doc_chunks.embedding
// from the system catalogs and returns its dimension (e.g. 1024 for
// "vector(1024)"). A missing column is reported as an error.
func fetchEmbeddingDim(ctx context.Context, db *pgxpool.Pool) (int, error) {
	const columnTypeSQL = `
SELECT format_type(a.atttypid, a.atttypmod)
FROM pg_attribute a
JOIN pg_class c ON a.attrelid = c.oid
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = 'public'
  AND c.relname = 'kb_doc_chunks'
  AND a.attname = 'embedding'
  AND a.attnum > 0
  AND NOT a.attisdropped
`
	var columnType string
	switch err := db.QueryRow(ctx, columnTypeSQL).Scan(&columnType); {
	case errors.Is(err, pgx.ErrNoRows):
		return 0, errors.New("kb_doc_chunks.embedding not found")
	case err != nil:
		return 0, err
	}
	return parseVectorTypeDim(columnType)
}
// verifyEmbeddingDim checks, within a 2-second timeout, that the database's
// embedding column dimension equals expected. It returns the lookup error or a
// mismatch error naming both values.
func verifyEmbeddingDim(db *pgxpool.Pool, expected int) error {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	actual, err := fetchEmbeddingDim(ctx, db)
	switch {
	case err != nil:
		return err
	case actual != expected:
		return fmt.Errorf("embedding dimension mismatch: EMBEDDING_DIM=%d db=%d", expected, actual)
	default:
		return nil
	}
}
// parseVectorTypeDim extracts N from a type name of the form "vector(N)"
// (surrounding whitespace ignored). N must be a positive integer; anything
// else is an "invalid embedding type" error quoting the trimmed input.
func parseVectorTypeDim(typeStr string) (int, error) {
	const prefix, suffix = "vector(", ")"
	s := strings.TrimSpace(typeStr)
	if strings.HasPrefix(s, prefix) && strings.HasSuffix(s, suffix) {
		digits := s[len(prefix) : len(s)-len(suffix)]
		if n, convErr := strconv.Atoi(digits); convErr == nil && n > 0 {
			return n, nil
		}
	}
	return 0, fmt.Errorf("invalid embedding type: %s", s)
}
// logConnectionInfo logs the database role and database name the pool is
// connected as, or logs the lookup failure. Bounded by a 2-second timeout.
func logConnectionInfo(db *pgxpool.Pool) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	var currentUser, currentDB string
	row := db.QueryRow(ctx, "SELECT current_user, current_database()")
	if scanErr := row.Scan(&currentUser, &currentDB); scanErr != nil {
		log.Printf("kb.search db user check failed: %v", scanErr)
		return
	}
	log.Printf("kb.search db user=%s db=%s", currentUser, currentDB)
}
// splitIntoChunks splits text into chunks of at most size runes, preferring
// paragraph (blank-line) boundaries.
//
// Consecutive paragraphs are packed into one chunk joined by "\n\n". When a
// chunk is emitted, its last overlap runes are carried into the next chunk as
// context, but only if that chunk still fits within size. A paragraph longer
// than size is cut on its own by splitByRunes. CRLF line endings are normalized
// to LF first. A non-positive size disables chunking and returns text unchanged.
// Text with no non-blank lines yields [""].
func splitIntoChunks(text string, size int, overlap int) []string {
	if size <= 0 {
		return []string{text}
	}
	const sepLen = 2 // runes in the "\n\n" paragraph separator
	text = strings.ReplaceAll(text, "\r\n", "\n")
	var chunks []string
	var buf []rune
	// flush emits buf as a chunk and keeps its tail as the next chunk's overlap.
	flush := func() {
		if len(buf) == 0 {
			return
		}
		chunks = append(chunks, string(buf))
		if overlap > 0 && len(buf) > overlap {
			buf = append([]rune(nil), buf[len(buf)-overlap:]...)
		} else {
			buf = buf[:0]
		}
	}
	for _, p := range splitParagraphs(text) {
		rp := []rune(p)
		if len(rp) > size {
			flush()
			chunks = append(chunks, splitByRunes(p, size, overlap)...)
			buf = buf[:0]
			continue
		}
		// Count the separator: without it a packed chunk could reach size+2 runes.
		if len(buf) > 0 && len(buf)+sepLen+len(rp) > size {
			flush()
			// The carried overlap must not push the new chunk past size either.
			if len(buf) > 0 && len(buf)+sepLen+len(rp) > size {
				buf = buf[:0]
			}
		}
		if len(buf) > 0 {
			buf = append(buf, '\n', '\n')
		}
		buf = append(buf, rp...)
	}
	flush()
	if len(chunks) == 0 {
		return []string{""}
	}
	return chunks
}

// splitParagraphs returns the paragraphs of text: maximal runs of lines that are
// not blank (whitespace-only), each re-joined with "\n".
func splitParagraphs(text string) []string {
	var paras []string
	var cur []string
	for _, line := range strings.Split(text, "\n") {
		if strings.TrimSpace(line) != "" {
			cur = append(cur, line)
			continue
		}
		if len(cur) > 0 {
			paras = append(paras, strings.Join(cur, "\n"))
			cur = cur[:0]
		}
	}
	if len(cur) > 0 {
		paras = append(paras, strings.Join(cur, "\n"))
	}
	return paras
}

// splitByRunes cuts text into windows of size runes, each starting overlap
// runes before the end of the previous one. A negative overlap is treated as 0,
// and an overlap >= size is reduced to size/4 so the window always advances.
// A non-positive size returns text unchanged: size 0 would never advance the
// window (infinite loop), and a negative size would panic when slicing.
func splitByRunes(text string, size int, overlap int) []string {
	r := []rune(text)
	if size <= 0 || len(r) <= size {
		return []string{text}
	}
	if overlap < 0 {
		overlap = 0
	}
	if overlap >= size {
		overlap = size / 4
	}
	var out []string
	for start := 0; ; start += size - overlap {
		end := start + size
		if end >= len(r) {
			return append(out, string(r[start:]))
		}
		out = append(out, string(r[start:end]))
	}
}
// validateCollection checks a collection name. It must be 1 to 50 bytes long
// and use only lowercase ASCII letters, digits and underscores.
func validateCollection(name string) error {
	if n := len(name); n == 0 || n > 50 {
		return errors.New("invalid length")
	}
	for _, c := range name {
		switch {
		case 'a' <= c && c <= 'z', '0' <= c && c <= '9', c == '_':
			// allowed
		default:
			return errors.New("invalid character")
		}
	}
	return nil
}
// initStore selects the API-key store.
//
// Without DATABASE_URL, it logs a warning and returns an in-memory store, a nil
// pool and a no-op closer. Otherwise it connects with a 5-second budget
// covering pool creation, a ping and schema setup. It returns a Postgres-backed
// store, the pool, and pool.Close as the closer. On any failure the pool is
// closed and the error is returned.
func initStore() (APIKeyStore, *pgxpool.Pool, func(), error) {
	noop := func() {}
	dbURL := os.Getenv("DATABASE_URL")
	if strings.TrimSpace(dbURL) == "" {
		log.Println("warning: DATABASE_URL not set; using in-memory api key store")
		return NewMemoryKeyStore(), nil, noop, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	pool, err := pgxpool.New(ctx, dbURL)
	if err != nil {
		return nil, nil, noop, err
	}
	// fail releases the pool before reporting a setup error.
	fail := func(cause error) (APIKeyStore, *pgxpool.Pool, func(), error) {
		pool.Close()
		return nil, nil, noop, cause
	}
	if pingErr := pool.Ping(ctx); pingErr != nil {
		return fail(pingErr)
	}
	store := NewPostgresKeyStore(pool)
	if schemaErr := store.EnsureSchema(ctx); schemaErr != nil {
		return fail(schemaErr)
	}
	return store, pool, pool.Close, nil
}
// main wires configuration, storage and routes together and serves HTTP on
// :$PORT (default 8080).
func main() {
	adminKey := os.Getenv("ADMIN_API_KEY")
	if adminKey == "" {
		log.Println("warning: ADMIN_API_KEY not set; admin endpoints will return 503")
	}
	store, db, closeStore, err := initStore()
	if err != nil {
		log.Fatal(err)
	}
	defer closeStore()
	if db != nil {
		if err := verifyEmbeddingDim(db, embeddingDim()); err != nil {
			log.Fatal(err)
		}
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/health", handleHealth)
	mux.Handle("/db/query", withAPIKeyAuth(store, http.HandlerFunc(handleNotImplemented)))
	mux.Handle("/kb/upsert", withAPIKeyAuth(store, handleKbUpsert(db)))
	mux.Handle("/kb/search", withAPIKeyAuth(store, handleKbSearch(db)))
	mux.Handle("/kb/delete", withAPIKeyAuth(store, handleKbDelete(db)))
	mux.Handle("/collections", withAPIKeyAuth(store, handleCollections(db)))
	adminHandler := withAdminAuth(adminKey, http.HandlerFunc(handleAdminApiKeys(store)))
	mux.Handle("/admin/api-keys", adminHandler)
	mux.Handle("/admin/api-keys/", adminHandler)
	mux.Handle("/admin/collections", withAdminAuth(adminKey, handleAdminCollections(db)))
	mux.Handle("/admin/audit/permissions", withAdminAuth(adminKey, handleAdminPermissionsAudit(db)))
	mux.Handle("/admin", withAdminUIAuth(adminKey, http.HandlerFunc(handleAdminUI)))
	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}
	addr := ":" + port
	// http.ListenAndServe uses a server with no timeouts, so a client can hold a
	// connection open indefinitely by trickling request headers (slowloris).
	// Bound header reads and idle keep-alives. Read/WriteTimeout stay unset so
	// large request bodies and slow embedding calls (each with its own 15s client
	// timeout) are not cut off mid-request.
	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.Printf("pgvecter API listening on %s", addr)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}