2242 lines
67 KiB
Go
2242 lines
67 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"crypto/sha256"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"math"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/jackc/pgx/v5"
|
||
"github.com/jackc/pgx/v5/pgxpool"
|
||
)
|
||
|
||
// ctxKey is an unexported type for request-context keys, so values stored by
// this package cannot collide with keys defined by other packages.
type ctxKey string

// Context keys populated by withAPIKeyAuth for downstream handlers.
const (
	// ctxPermissionsUsers maps to the authenticated key's []string permissions.
	ctxPermissionsUsers = ctxKey("permissions_users")
	// ctxAPIKeyID maps to the authenticated key's ID (a string).
	ctxAPIKeyID = ctxKey("api_key_id")
)

// ErrNotFound is returned by the key stores when no key has the requested ID.
var ErrNotFound = errors.New("not found")
|
||
|
||
// Generic response envelopes.
type (
	// ErrorResponse is a JSON error body: a human-readable message plus a
	// machine-readable code (presumably what writeError emits — its call
	// sites pass a code and a message).
	ErrorResponse struct {
		Error string `json:"error"`
		Code  string `json:"code"`
	}

	// HealthResponse is the body written by handleHealth.
	HealthResponse struct {
		Status    string    `json:"status"`
		Timestamp time.Time `json:"timestamp"`
	}
)
|
||
|
||
// Request/response types for the knowledge-base search and collection endpoints.
type (
	// KbSearchRequest is the body accepted by handleKbSearch. The handler
	// requires exactly one of Collection or Collections, defaults TopK to 5
	// and limits it to 1..50, and computes QueryEmbedding from Query when it
	// is omitted.
	KbSearchRequest struct {
		Query          string         `json:"query"`
		QueryEmbedding []float64      `json:"query_embedding"`
		TopK           int            `json:"top_k"`
		Filter         map[string]any `json:"filter"`
		Collection     string         `json:"collection"`
		Collections    []string       `json:"collections"`
	}

	// KbSearchResponse wraps the hits returned by handleKbSearch.
	KbSearchResponse struct {
		Results []KbSearchResult `json:"results"`
	}

	// KbSearchResult is a single hit. Score is produced by
	// distanceToSimilarity from the pgvector distance (0 when the distance
	// is NaN or infinite).
	KbSearchResult struct {
		ID         string         `json:"id"`
		Collection string         `json:"collection"`
		Content    string         `json:"content"`
		Metadata   map[string]any `json:"metadata"`
		Score      float64        `json:"score"`
	}

	// CollectionItem names a collection; Count is presumably the number of
	// stored chunks in it (confirm against the collections handler).
	CollectionItem struct {
		Name  string `json:"name"`
		Count int    `json:"count"`
	}

	// CollectionsResponse lists collections.
	CollectionsResponse struct {
		Items []CollectionItem `json:"items"`
	}
)
|
||
|
||
// Types describing recorded attempts to supply metadata.permissions.
type (
	// PermissionsAuditItem is one audit record: which key made the attempt,
	// that key's own permissions, where the request came from, and what it
	// tried to supply. RemoteAddr and UserAgent are omitted when empty.
	PermissionsAuditItem struct {
		ID                  string         `json:"id"`
		APIKeyID            string         `json:"api_key_id"`
		PermissionsUsers    []string       `json:"permissions_users"`
		RequestedAt         time.Time      `json:"requested_at"`
		Endpoint            string         `json:"endpoint"`
		RemoteAddr          string         `json:"remote_addr,omitempty"`
		UserAgent           string         `json:"user_agent,omitempty"`
		ProvidedPermissions map[string]any `json:"provided_permissions"`
		ProvidedMetadata    map[string]any `json:"provided_metadata"`
	}

	// PermissionsAuditResponse lists audit records.
	PermissionsAuditResponse struct {
		Items []PermissionsAuditItem `json:"items"`
	}
)
|
||
|
||
// Request/response types for the knowledge-base upsert and delete endpoints.
type (
	// KbUpsertRequest is the body accepted by handleKbUpsert. Zero values are
	// replaced by the handler: Collection "default", ChunkSize 800, Overlap
	// 100, and a generated DocID. ID is required unless AutoChunk is set, and
	// Embedding may not be combined with AutoChunk.
	KbUpsertRequest struct {
		ID         string         `json:"id"`
		DocID      string         `json:"doc_id"`
		ChunkIndex int            `json:"chunk_index"`
		Collection string         `json:"collection"`
		Content    string         `json:"content"`
		Metadata   map[string]any `json:"metadata"`
		Embedding  []float64      `json:"embedding"`
		AutoChunk  bool           `json:"auto_chunk"`
		ChunkSize  int            `json:"chunk_size"`
		Overlap    int            `json:"chunk_overlap"` // wire name differs from the field name
	}

	// KbUpsertResponse reports the stored row(s). Chunks is populated only
	// when the content was split, in which case ID is the first chunk's ID.
	KbUpsertResponse struct {
		ID      string `json:"id"`
		DocID   string `json:"doc_id"`
		Updated bool   `json:"updated"`
		// The element type is anonymous and handleKbUpsert appends literals of
		// this exact struct type, so its field names and tags must not change.
		Chunks []struct {
			ID         string `json:"id"`
			ChunkIndex int    `json:"chunk_index"`
		} `json:"chunks,omitempty"`
	}

	// KbDeleteRequest is the body accepted by handleKbDelete. ID or Filter is
	// required; DocID and Collection further narrow the match. DryRun only
	// counts the matching rows.
	KbDeleteRequest struct {
		ID         string         `json:"id"`
		DocID      string         `json:"doc_id"`
		Collection string         `json:"collection"`
		Filter     map[string]any `json:"filter"`
		DryRun     bool           `json:"dry_run"`
	}

	// KbDeleteResponse reports how many rows were deleted (or would be, for
	// a dry run).
	KbDeleteResponse struct {
		Deleted int64 `json:"deleted"`
	}
)
|
||
|
||
// Admin API types for managing API keys.
type (
	// AdminApiKey is the public view of a stored API key. The raw secret is
	// not part of it; the stores keep only a hash. The status values set in
	// this file are "active" and "revoked".
	AdminApiKey struct {
		ID               string     `json:"id"`
		Label            string     `json:"label"`
		PermissionsUsers []string   `json:"permissions_users"`
		Status           string     `json:"status"`
		CreatedAt        time.Time  `json:"created_at"`
		LastUsedAt       *time.Time `json:"last_used_at,omitempty"` // nil until first successful Authenticate
	}

	// AdminApiKeyListResponse lists API keys.
	AdminApiKeyListResponse struct {
		Items []AdminApiKey `json:"items"`
	}

	// AdminApiKeyCreateRequest asks for a new key; both fields are required.
	AdminApiKeyCreateRequest struct {
		Label            string   `json:"label"`
		PermissionsUsers []string `json:"permissions_users"`
	}

	// AdminApiKeyCreateResponse carries the raw secret (APIKey), which the
	// stores return only from Create, alongside the stored key record.
	AdminApiKeyCreateResponse struct {
		APIKey string      `json:"api_key"`
		Key    AdminApiKey `json:"key"`
	}

	// AdminApiKeyUpdateRequest is a partial update: a nil Label or nil
	// PermissionsUsers leaves that field unchanged.
	AdminApiKeyUpdateRequest struct {
		Label            *string  `json:"label"`
		PermissionsUsers []string `json:"permissions_users"`
	}
)
|
||
|
||
type APIKeyStore interface {
|
||
Create(ctx context.Context, label string, permissions []string) (string, AdminApiKey, error)
|
||
List(ctx context.Context) ([]AdminApiKey, error)
|
||
Update(ctx context.Context, id string, label *string, permissions []string) (AdminApiKey, error)
|
||
Revoke(ctx context.Context, id string) (AdminApiKey, error)
|
||
Authenticate(ctx context.Context, apiKey string) (*AdminApiKey, bool, error)
|
||
}
|
||
|
||
type MemoryKeyStore struct {
|
||
mu sync.Mutex
|
||
keys map[string]*AdminApiKey
|
||
hashToID map[string]string
|
||
}
|
||
|
||
func NewMemoryKeyStore() *MemoryKeyStore {
|
||
return &MemoryKeyStore{
|
||
keys: make(map[string]*AdminApiKey),
|
||
hashToID: make(map[string]string),
|
||
}
|
||
}
|
||
|
||
func (s *MemoryKeyStore) Create(_ context.Context, label string, permissions []string) (string, AdminApiKey, error) {
|
||
if strings.TrimSpace(label) == "" {
|
||
return "", AdminApiKey{}, errors.New("label is required")
|
||
}
|
||
if len(permissions) == 0 {
|
||
return "", AdminApiKey{}, errors.New("permissions_users is required")
|
||
}
|
||
rawKey, err := generateAPIKey()
|
||
if err != nil {
|
||
return "", AdminApiKey{}, err
|
||
}
|
||
keyHash := hashKey(rawKey)
|
||
now := time.Now().UTC()
|
||
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
id := generateID()
|
||
apiKey := AdminApiKey{
|
||
ID: id,
|
||
Label: label,
|
||
PermissionsUsers: append([]string(nil), permissions...),
|
||
Status: "active",
|
||
CreatedAt: now,
|
||
LastUsedAt: nil,
|
||
}
|
||
s.keys[id] = &apiKey
|
||
s.hashToID[keyHash] = id
|
||
return rawKey, apiKey, nil
|
||
}
|
||
|
||
func (s *MemoryKeyStore) List(_ context.Context) ([]AdminApiKey, error) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
items := make([]AdminApiKey, 0, len(s.keys))
|
||
for _, k := range s.keys {
|
||
items = append(items, *k)
|
||
}
|
||
sort.Slice(items, func(i, j int) bool {
|
||
return items[i].CreatedAt.Before(items[j].CreatedAt)
|
||
})
|
||
return items, nil
|
||
}
|
||
|
||
func (s *MemoryKeyStore) Update(_ context.Context, id string, label *string, permissions []string) (AdminApiKey, error) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
k, ok := s.keys[id]
|
||
if !ok {
|
||
return AdminApiKey{}, ErrNotFound
|
||
}
|
||
if label != nil {
|
||
if strings.TrimSpace(*label) == "" {
|
||
return AdminApiKey{}, errors.New("label is required")
|
||
}
|
||
k.Label = *label
|
||
}
|
||
if permissions != nil {
|
||
if len(permissions) == 0 {
|
||
return AdminApiKey{}, errors.New("permissions_users is required")
|
||
}
|
||
k.PermissionsUsers = append([]string(nil), permissions...)
|
||
}
|
||
return *k, nil
|
||
}
|
||
|
||
func (s *MemoryKeyStore) Revoke(_ context.Context, id string) (AdminApiKey, error) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
k, ok := s.keys[id]
|
||
if !ok {
|
||
return AdminApiKey{}, ErrNotFound
|
||
}
|
||
k.Status = "revoked"
|
||
return *k, nil
|
||
}
|
||
|
||
func (s *MemoryKeyStore) Authenticate(_ context.Context, apiKey string) (*AdminApiKey, bool, error) {
|
||
keyHash := hashKey(apiKey)
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
id, ok := s.hashToID[keyHash]
|
||
if !ok {
|
||
return nil, false, nil
|
||
}
|
||
k, ok := s.keys[id]
|
||
if !ok || k.Status != "active" {
|
||
return nil, false, nil
|
||
}
|
||
now := time.Now().UTC()
|
||
k.LastUsedAt = &now
|
||
return k, true, nil
|
||
}
|
||
|
||
type PostgresKeyStore struct {
|
||
pool *pgxpool.Pool
|
||
}
|
||
|
||
func NewPostgresKeyStore(pool *pgxpool.Pool) *PostgresKeyStore {
|
||
return &PostgresKeyStore{pool: pool}
|
||
}
|
||
|
||
func (s *PostgresKeyStore) EnsureSchema(ctx context.Context) error {
|
||
schema := `
|
||
CREATE TABLE IF NOT EXISTS api_keys (
|
||
id text PRIMARY KEY,
|
||
label text NOT NULL,
|
||
api_key_hash text NOT NULL UNIQUE,
|
||
permissions_users text[] NOT NULL,
|
||
status text NOT NULL,
|
||
created_at timestamptz NOT NULL,
|
||
last_used_at timestamptz
|
||
);
|
||
`
|
||
_, err := s.pool.Exec(ctx, schema)
|
||
return err
|
||
}
|
||
|
||
func (s *PostgresKeyStore) Create(ctx context.Context, label string, permissions []string) (string, AdminApiKey, error) {
|
||
if strings.TrimSpace(label) == "" {
|
||
return "", AdminApiKey{}, errors.New("label is required")
|
||
}
|
||
if len(permissions) == 0 {
|
||
return "", AdminApiKey{}, errors.New("permissions_users is required")
|
||
}
|
||
rawKey, err := generateAPIKey()
|
||
if err != nil {
|
||
return "", AdminApiKey{}, err
|
||
}
|
||
keyHash := hashKey(rawKey)
|
||
now := time.Now().UTC()
|
||
id := generateID()
|
||
|
||
var key AdminApiKey
|
||
row := s.pool.QueryRow(ctx, `
|
||
INSERT INTO api_keys (id, label, api_key_hash, permissions_users, status, created_at)
|
||
VALUES ($1, $2, $3, $4, 'active', $5)
|
||
RETURNING id, label, permissions_users, status, created_at, last_used_at
|
||
`, id, label, keyHash, permissions, now)
|
||
if err := row.Scan(&key.ID, &key.Label, &key.PermissionsUsers, &key.Status, &key.CreatedAt, &key.LastUsedAt); err != nil {
|
||
return "", AdminApiKey{}, err
|
||
}
|
||
return rawKey, key, nil
|
||
}
|
||
|
||
func (s *PostgresKeyStore) List(ctx context.Context) ([]AdminApiKey, error) {
|
||
rows, err := s.pool.Query(ctx, `
|
||
SELECT id, label, permissions_users, status, created_at, last_used_at
|
||
FROM api_keys
|
||
ORDER BY created_at ASC
|
||
`)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
var items []AdminApiKey
|
||
for rows.Next() {
|
||
var k AdminApiKey
|
||
if err := rows.Scan(&k.ID, &k.Label, &k.PermissionsUsers, &k.Status, &k.CreatedAt, &k.LastUsedAt); err != nil {
|
||
return nil, err
|
||
}
|
||
items = append(items, k)
|
||
}
|
||
return items, rows.Err()
|
||
}
|
||
|
||
func (s *PostgresKeyStore) Update(ctx context.Context, id string, label *string, permissions []string) (AdminApiKey, error) {
|
||
if label == nil && permissions == nil {
|
||
return AdminApiKey{}, errors.New("no fields to update")
|
||
}
|
||
if label != nil && strings.TrimSpace(*label) == "" {
|
||
return AdminApiKey{}, errors.New("label is required")
|
||
}
|
||
if permissions != nil && len(permissions) == 0 {
|
||
return AdminApiKey{}, errors.New("permissions_users is required")
|
||
}
|
||
|
||
var key AdminApiKey
|
||
row := s.pool.QueryRow(ctx, `
|
||
UPDATE api_keys
|
||
SET label = COALESCE($2, label),
|
||
permissions_users = COALESCE($3, permissions_users)
|
||
WHERE id = $1
|
||
RETURNING id, label, permissions_users, status, created_at, last_used_at
|
||
`, id, label, permissions)
|
||
if err := row.Scan(&key.ID, &key.Label, &key.PermissionsUsers, &key.Status, &key.CreatedAt, &key.LastUsedAt); err != nil {
|
||
if errors.Is(err, pgx.ErrNoRows) {
|
||
return AdminApiKey{}, ErrNotFound
|
||
}
|
||
return AdminApiKey{}, err
|
||
}
|
||
return key, nil
|
||
}
|
||
|
||
func (s *PostgresKeyStore) Revoke(ctx context.Context, id string) (AdminApiKey, error) {
|
||
var key AdminApiKey
|
||
row := s.pool.QueryRow(ctx, `
|
||
UPDATE api_keys
|
||
SET status = 'revoked'
|
||
WHERE id = $1
|
||
RETURNING id, label, permissions_users, status, created_at, last_used_at
|
||
`, id)
|
||
if err := row.Scan(&key.ID, &key.Label, &key.PermissionsUsers, &key.Status, &key.CreatedAt, &key.LastUsedAt); err != nil {
|
||
if errors.Is(err, pgx.ErrNoRows) {
|
||
return AdminApiKey{}, ErrNotFound
|
||
}
|
||
return AdminApiKey{}, err
|
||
}
|
||
return key, nil
|
||
}
|
||
|
||
func (s *PostgresKeyStore) Authenticate(ctx context.Context, apiKey string) (*AdminApiKey, bool, error) {
|
||
keyHash := hashKey(apiKey)
|
||
var key AdminApiKey
|
||
err := s.pool.QueryRow(ctx, `
|
||
SELECT id, label, permissions_users, status, created_at, last_used_at
|
||
FROM api_keys
|
||
WHERE api_key_hash = $1
|
||
`, keyHash).Scan(&key.ID, &key.Label, &key.PermissionsUsers, &key.Status, &key.CreatedAt, &key.LastUsedAt)
|
||
if err != nil {
|
||
if errors.Is(err, pgx.ErrNoRows) {
|
||
return nil, false, nil
|
||
}
|
||
return nil, false, err
|
||
}
|
||
if key.Status != "active" {
|
||
return nil, false, nil
|
||
}
|
||
now := time.Now().UTC()
|
||
_, err = s.pool.Exec(ctx, `UPDATE api_keys SET last_used_at = $2 WHERE id = $1`, key.ID, now)
|
||
if err != nil {
|
||
return nil, false, err
|
||
}
|
||
key.LastUsedAt = &now
|
||
return &key, true, nil
|
||
}
|
||
|
||
// generateAPIKey returns a new secret: "pgv_" followed by 32 bytes from
// crypto/rand, encoded as unpadded URL-safe base64 (47 characters total).
func generateAPIKey() (string, error) {
	var secret [32]byte
	if _, err := rand.Read(secret[:]); err != nil {
		return "", err
	}
	const prefix = "pgv_"
	return prefix + base64.RawURLEncoding.EncodeToString(secret[:]), nil
}
|
||
|
||
// generateID returns a random RFC 4122 version-4 UUID string in the canonical
// 8-4-4-4-12 lowercase hex form.
//
// The original ignored the rand.Read error. If the entropy source failed,
// that produced predictable (e.g. all-zero-derived) IDs that would collide as
// primary keys. Since an ID cannot be produced safely without randomness,
// this now panics; crypto/rand itself aborts the process on failure from
// Go 1.24 onward.
func generateID() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic("generateID: crypto/rand failed: " + err.Error())
	}
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant

	var out [36]byte
	hex.Encode(out[0:8], b[0:4])
	out[8] = '-'
	hex.Encode(out[9:13], b[4:6])
	out[13] = '-'
	hex.Encode(out[14:18], b[6:8])
	out[18] = '-'
	hex.Encode(out[19:23], b[8:10])
	out[23] = '-'
	hex.Encode(out[24:36], b[10:16])
	return string(out[:])
}
|
||
|
||
// parseUUID checks that value has the canonical 8-4-4-4-12 hex layout
// (either letter case) and returns it unchanged. Any other input yields an
// "invalid uuid format" error. Version and variant bits are not checked.
func parseUUID(value string) (string, error) {
	invalid := func() (string, error) {
		return "", errors.New("invalid uuid format")
	}
	if len(value) != 36 {
		return invalid()
	}
	for pos := 0; pos < len(value); pos++ {
		c := value[pos]
		switch pos {
		case 8, 13, 18, 23:
			if c != '-' {
				return invalid()
			}
		default:
			isHex := ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')
			if !isHex {
				return invalid()
			}
		}
	}
	return value, nil
}
|
||
|
||
// hashKey returns the lowercase hex SHA-256 digest of raw. Only this digest
// of an API key is ever stored.
func hashKey(raw string) string {
	h := sha256.New()
	h.Write([]byte(raw)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(h.Sum(nil))
}
|
||
|
||
func withAdminAuth(adminKey string, next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if adminKey == "" {
|
||
writeError(w, http.StatusServiceUnavailable, "admin_key_not_configured", "ADMIN_API_KEY is not set")
|
||
return
|
||
}
|
||
key := r.Header.Get("X-ADMIN-API-KEY")
|
||
if key == "" || key != adminKey {
|
||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid admin api key")
|
||
return
|
||
}
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
func withAPIKeyAuth(store APIKeyStore, next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
key := r.Header.Get("X-API-KEY")
|
||
if key == "" {
|
||
writeError(w, http.StatusUnauthorized, "unauthorized", "missing api key")
|
||
return
|
||
}
|
||
k, ok, err := store.Authenticate(r.Context(), key)
|
||
if err != nil {
|
||
writeError(w, http.StatusInternalServerError, "auth_failed", "failed to authenticate api key")
|
||
return
|
||
}
|
||
if !ok || k == nil {
|
||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid api key")
|
||
return
|
||
}
|
||
ctx := context.WithValue(r.Context(), ctxPermissionsUsers, k.PermissionsUsers)
|
||
ctx = context.WithValue(ctx, ctxAPIKeyID, k.ID)
|
||
next.ServeHTTP(w, r.WithContext(ctx))
|
||
})
|
||
}
|
||
|
||
func handleHealth(w http.ResponseWriter, r *http.Request) {
|
||
writeJSON(w, http.StatusOK, HealthResponse{
|
||
Status: "ok",
|
||
Timestamp: time.Now().UTC(),
|
||
})
|
||
}
|
||
|
||
func handleNotImplemented(w http.ResponseWriter, r *http.Request) {
|
||
writeError(w, http.StatusNotImplemented, "not_implemented", "endpoint not implemented yet")
|
||
}
|
||
|
||
func handleKbUpsert(db *pgxpool.Pool) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if db == nil {
|
||
writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
|
||
return
|
||
}
|
||
permissions, ok := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if !ok || len(permissions) == 0 {
|
||
writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
|
||
return
|
||
}
|
||
var req KbUpsertRequest
|
||
if err := decodeJSON(r, &req); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
if hasMetadataPermissions(req.Metadata) {
|
||
apiKeyID, _ := r.Context().Value(ctxAPIKeyID).(string)
|
||
permissions, _ := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if err := recordPermissionsAudit(r.Context(), db, apiKeyID, permissions, r, req.Metadata["permissions"], req.Metadata); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
|
||
return
|
||
}
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "metadata.permissions is not allowed")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.Content) == "" {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "content is required")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.Collection) == "" {
|
||
req.Collection = "default"
|
||
}
|
||
if req.ChunkSize == 0 {
|
||
req.ChunkSize = 800
|
||
}
|
||
if req.Overlap == 0 {
|
||
req.Overlap = 100
|
||
}
|
||
if req.ChunkSize < 200 {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "chunk_size is too small")
|
||
return
|
||
}
|
||
if req.Overlap < 0 || req.Overlap >= req.ChunkSize {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "chunk_overlap must be between 0 and chunk_size-1")
|
||
return
|
||
}
|
||
dim := embeddingDim()
|
||
if len(req.Embedding) > 0 && req.AutoChunk {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "embedding cannot be provided when auto_chunk is true")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.DocID) == "" {
|
||
req.DocID = generateID()
|
||
} else {
|
||
if _, err := parseUUID(req.DocID); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "doc_id must be uuid")
|
||
return
|
||
}
|
||
}
|
||
if req.ChunkIndex < 0 {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "chunk_index must be >= 0")
|
||
return
|
||
}
|
||
if !req.AutoChunk {
|
||
if strings.TrimSpace(req.ID) == "" {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "id is required")
|
||
return
|
||
}
|
||
if _, err := parseUUID(req.ID); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "id must be uuid")
|
||
return
|
||
}
|
||
}
|
||
metadata := req.Metadata
|
||
if metadata == nil {
|
||
metadata = map[string]interface{}{}
|
||
}
|
||
metadata["collection"] = req.Collection
|
||
perms := map[string]interface{}{
|
||
"users": permissions,
|
||
}
|
||
metadata["permissions"] = perms
|
||
metaBytes, err := json.Marshal(metadata)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "metadata must be valid object")
|
||
return
|
||
}
|
||
|
||
if req.AutoChunk || len([]rune(req.Content)) > req.ChunkSize {
|
||
chunks := splitIntoChunks(req.Content, req.ChunkSize, req.Overlap)
|
||
resp := KbUpsertResponse{DocID: req.DocID, Updated: false}
|
||
for i, chunk := range chunks {
|
||
emb, err := createEmbeddingWithDim(r.Context(), chunk, dim)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "embedding_failed", err.Error())
|
||
return
|
||
}
|
||
vectorParam := vectorLiteral(emb)
|
||
id := generateID()
|
||
var updated bool
|
||
err = db.QueryRow(r.Context(), `
|
||
INSERT INTO kb_doc_chunks (id, doc_id, chunk_index, collection, content, metadata, embedding, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::vector, now())
|
||
ON CONFLICT (id) DO UPDATE
|
||
SET doc_id = EXCLUDED.doc_id,
|
||
chunk_index = EXCLUDED.chunk_index,
|
||
collection = EXCLUDED.collection,
|
||
content = EXCLUDED.content,
|
||
metadata = EXCLUDED.metadata,
|
||
embedding = EXCLUDED.embedding,
|
||
updated_at = now()
|
||
RETURNING (xmax <> 0) AS updated
|
||
`, id, req.DocID, i, req.Collection, chunk, string(metaBytes), vectorParam).Scan(&updated)
|
||
if err != nil {
|
||
log.Printf("kb.upsert failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "upsert_failed", fmt.Sprintf("failed to upsert document: %v", err))
|
||
return
|
||
}
|
||
resp.Updated = resp.Updated || updated
|
||
resp.Chunks = append(resp.Chunks, struct {
|
||
ID string `json:"id"`
|
||
ChunkIndex int `json:"chunk_index"`
|
||
}{ID: id, ChunkIndex: i})
|
||
}
|
||
if len(resp.Chunks) > 0 {
|
||
resp.ID = resp.Chunks[0].ID
|
||
}
|
||
writeJSON(w, http.StatusOK, resp)
|
||
return
|
||
}
|
||
|
||
if len(req.Embedding) == 0 {
|
||
emb, err := createEmbeddingWithDim(r.Context(), req.Content, dim)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "embedding_failed", err.Error())
|
||
return
|
||
}
|
||
req.Embedding = emb
|
||
}
|
||
if len(req.Embedding) != dim {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("embedding must be length %d", dim))
|
||
return
|
||
}
|
||
vectorParam := vectorLiteral(req.Embedding)
|
||
var updated bool
|
||
err = db.QueryRow(r.Context(), `
|
||
INSERT INTO kb_doc_chunks (id, doc_id, chunk_index, collection, content, metadata, embedding, updated_at)
|
||
VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::vector, now())
|
||
ON CONFLICT (id) DO UPDATE
|
||
SET doc_id = EXCLUDED.doc_id,
|
||
chunk_index = EXCLUDED.chunk_index,
|
||
collection = EXCLUDED.collection,
|
||
content = EXCLUDED.content,
|
||
metadata = EXCLUDED.metadata,
|
||
embedding = EXCLUDED.embedding,
|
||
updated_at = now()
|
||
RETURNING (xmax <> 0) AS updated
|
||
`, req.ID, req.DocID, req.ChunkIndex, req.Collection, req.Content, string(metaBytes), vectorParam).Scan(&updated)
|
||
if err != nil {
|
||
log.Printf("kb.upsert failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "upsert_failed", fmt.Sprintf("failed to upsert document: %v", err))
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, KbUpsertResponse{ID: req.ID, DocID: req.DocID, Updated: updated})
|
||
}
|
||
}
|
||
|
||
func handleKbSearch(db *pgxpool.Pool) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if db == nil {
|
||
writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
|
||
return
|
||
}
|
||
logConnectionInfo(db)
|
||
var req KbSearchRequest
|
||
if err := decodeJSON(r, &req); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.Query) == "" {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "query is required")
|
||
return
|
||
}
|
||
if req.TopK == 0 {
|
||
req.TopK = 5
|
||
}
|
||
if req.TopK < 1 || req.TopK > 50 {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "top_k must be between 1 and 50")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.Collection) == "" && len(req.Collections) == 0 {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "collection or collections is required")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.Collection) != "" && len(req.Collections) > 0 {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "collection and collections are mutually exclusive")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.Collection) != "" {
|
||
if err := validateCollection(req.Collection); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "collection is invalid")
|
||
return
|
||
}
|
||
}
|
||
if len(req.Collections) > 0 {
|
||
for _, name := range req.Collections {
|
||
if err := validateCollection(name); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "collections contains invalid value")
|
||
return
|
||
}
|
||
}
|
||
}
|
||
dim := embeddingDim()
|
||
if len(req.QueryEmbedding) == 0 {
|
||
emb, err := createEmbeddingWithDim(r.Context(), req.Query, dim)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "embedding_failed", err.Error())
|
||
return
|
||
}
|
||
req.QueryEmbedding = emb
|
||
}
|
||
if len(req.QueryEmbedding) != dim {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("query_embedding must be length %d", dim))
|
||
return
|
||
}
|
||
|
||
permissions, ok := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if !ok || len(permissions) == 0 {
|
||
writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
|
||
return
|
||
}
|
||
|
||
permFilter := map[string]interface{}{
|
||
"contains": map[string]interface{}{
|
||
"metadata.permissions.users": permissions,
|
||
},
|
||
}
|
||
mergedFilter := permFilter
|
||
if req.Filter != nil {
|
||
if filterHasMetadataPermissions(req.Filter) {
|
||
apiKeyID, _ := r.Context().Value(ctxAPIKeyID).(string)
|
||
permissions, _ := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if err := recordPermissionsAudit(r.Context(), db, apiKeyID, permissions, r, map[string]interface{}{"filter": req.Filter}, map[string]interface{}{"filter": req.Filter}); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
|
||
return
|
||
}
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "metadata.permissions filter is not allowed")
|
||
return
|
||
}
|
||
mergedFilter = map[string]interface{}{
|
||
"and": []interface{}{req.Filter, permFilter},
|
||
}
|
||
}
|
||
|
||
whereSQL, params, nextIdx, err := buildFilterSQL(mergedFilter, 3)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
|
||
vectorParam := vectorLiteral(req.QueryEmbedding)
|
||
if strings.TrimSpace(req.Collection) != "" {
|
||
params = append([]interface{}{vectorParam, req.Collection}, params...)
|
||
} else {
|
||
params = append([]interface{}{vectorParam, req.Collections}, params...)
|
||
}
|
||
params = append(params, req.TopK)
|
||
limitIdx := nextIdx
|
||
|
||
var query string
|
||
if strings.TrimSpace(req.Collection) != "" {
|
||
query = fmt.Sprintf(`
|
||
SELECT id::text, collection, content, metadata, (embedding <=> $1) AS score
|
||
FROM kb_doc_chunks
|
||
WHERE collection = $2 AND %s
|
||
ORDER BY embedding <=> $1
|
||
LIMIT $%d
|
||
`, whereSQL, limitIdx)
|
||
} else {
|
||
query = fmt.Sprintf(`
|
||
SELECT id::text, collection, content, metadata, (embedding <=> $1) AS score
|
||
FROM kb_doc_chunks
|
||
WHERE collection = ANY($2) AND %s
|
||
ORDER BY embedding <=> $1
|
||
LIMIT $%d
|
||
`, whereSQL, limitIdx)
|
||
}
|
||
|
||
rows, err := db.Query(r.Context(), query, params...)
|
||
if err != nil {
|
||
log.Printf("kb.search query failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "search_failed", fmt.Sprintf("failed to execute search: %v", err))
|
||
return
|
||
}
|
||
defer rows.Close()
|
||
|
||
results := make([]KbSearchResult, 0)
|
||
for rows.Next() {
|
||
var res KbSearchResult
|
||
var metaBytes []byte
|
||
if err := rows.Scan(&res.ID, &res.Collection, &res.Content, &metaBytes, &res.Score); err != nil {
|
||
log.Printf("kb.search scan failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "search_failed", fmt.Sprintf("failed to read search results: %v", err))
|
||
return
|
||
}
|
||
if math.IsNaN(res.Score) || math.IsInf(res.Score, 0) {
|
||
res.Score = 0
|
||
} else {
|
||
res.Score = distanceToSimilarity(res.Score)
|
||
}
|
||
if len(metaBytes) > 0 {
|
||
if err := json.Unmarshal(metaBytes, &res.Metadata); err != nil {
|
||
log.Printf("kb.search metadata unmarshal failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "search_failed", fmt.Sprintf("failed to parse search results: %v", err))
|
||
return
|
||
}
|
||
}
|
||
results = append(results, res)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
log.Printf("kb.search rows failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "search_failed", "failed to read search results")
|
||
return
|
||
}
|
||
|
||
writeJSON(w, http.StatusOK, KbSearchResponse{Results: results})
|
||
}
|
||
}
|
||
|
||
func handleKbDelete(db *pgxpool.Pool) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if db == nil {
|
||
writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
|
||
return
|
||
}
|
||
permissions, ok := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if !ok || len(permissions) == 0 {
|
||
writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
|
||
return
|
||
}
|
||
apiKeyID, _ := r.Context().Value(ctxAPIKeyID).(string)
|
||
var req KbDeleteRequest
|
||
if err := decodeJSON(r, &req); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.ID) == "" && req.Filter == nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "id or filter is required")
|
||
return
|
||
}
|
||
if strings.TrimSpace(req.ID) != "" {
|
||
if _, err := parseUUID(req.ID); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "id must be uuid")
|
||
return
|
||
}
|
||
}
|
||
if strings.TrimSpace(req.DocID) != "" {
|
||
if _, err := parseUUID(req.DocID); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "doc_id must be uuid")
|
||
return
|
||
}
|
||
}
|
||
if strings.TrimSpace(req.Collection) != "" {
|
||
if err := validateCollection(req.Collection); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "collection is invalid")
|
||
return
|
||
}
|
||
}
|
||
|
||
permFilter := map[string]interface{}{
|
||
"contains": map[string]interface{}{
|
||
"metadata.permissions.users": permissions,
|
||
},
|
||
}
|
||
mergedFilter := permFilter
|
||
if req.Filter != nil {
|
||
if filterHasMetadataPermissions(req.Filter) {
|
||
apiKeyID, _ := r.Context().Value(ctxAPIKeyID).(string)
|
||
permissions, _ := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if err := recordPermissionsAudit(r.Context(), db, apiKeyID, permissions, r, map[string]interface{}{"filter": req.Filter}, map[string]interface{}{"filter": req.Filter}); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
|
||
return
|
||
}
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "metadata.permissions filter is not allowed")
|
||
return
|
||
}
|
||
mergedFilter = map[string]interface{}{
|
||
"and": []interface{}{req.Filter, permFilter},
|
||
}
|
||
}
|
||
|
||
clauses := make([]string, 0, 4)
|
||
params := make([]interface{}, 0, 4)
|
||
idx := 1
|
||
if strings.TrimSpace(req.ID) != "" {
|
||
clauses = append(clauses, fmt.Sprintf("id = $%d", idx))
|
||
params = append(params, req.ID)
|
||
idx++
|
||
}
|
||
if strings.TrimSpace(req.DocID) != "" {
|
||
clauses = append(clauses, fmt.Sprintf("doc_id = $%d", idx))
|
||
params = append(params, req.DocID)
|
||
idx++
|
||
}
|
||
if strings.TrimSpace(req.Collection) != "" {
|
||
clauses = append(clauses, fmt.Sprintf("collection = $%d", idx))
|
||
params = append(params, req.Collection)
|
||
idx++
|
||
}
|
||
|
||
whereSQL, filterParams, _, err := buildFilterSQL(mergedFilter, idx)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
clauses = append(clauses, whereSQL)
|
||
params = append(params, filterParams...)
|
||
|
||
if req.DryRun {
|
||
query := fmt.Sprintf("SELECT COUNT(*) FROM kb_doc_chunks WHERE %s", strings.Join(clauses, " AND "))
|
||
var count int64
|
||
if err := db.QueryRow(r.Context(), query, params...).Scan(&count); err != nil {
|
||
log.Printf("kb.delete dry-run failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "delete_failed", fmt.Sprintf("failed to count: %v", err))
|
||
return
|
||
}
|
||
if err := recordDeleteAudit(r.Context(), db, apiKeyID, permissions, req, count); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, KbDeleteResponse{Deleted: count})
|
||
return
|
||
}
|
||
|
||
query := fmt.Sprintf("DELETE FROM kb_doc_chunks WHERE %s", strings.Join(clauses, " AND "))
|
||
ct, err := db.Exec(r.Context(), query, params...)
|
||
if err != nil {
|
||
log.Printf("kb.delete failed: %v", err)
|
||
writeError(w, http.StatusInternalServerError, "delete_failed", fmt.Sprintf("failed to delete: %v", err))
|
||
return
|
||
}
|
||
if err := recordDeleteAudit(r.Context(), db, apiKeyID, permissions, req, ct.RowsAffected()); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", err.Error())
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, KbDeleteResponse{Deleted: ct.RowsAffected()})
|
||
}
|
||
}
|
||
|
||
func recordDeleteAudit(ctx context.Context, db *pgxpool.Pool, apiKeyID string, permissions []string, req KbDeleteRequest, deleted int64) error {
|
||
var filterJSON any
|
||
if req.Filter != nil {
|
||
raw, err := json.Marshal(req.Filter)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid filter for audit: %w", err)
|
||
}
|
||
filterJSON = string(raw)
|
||
}
|
||
var idParam any
|
||
if strings.TrimSpace(req.ID) != "" {
|
||
idParam = req.ID
|
||
}
|
||
var docIDParam any
|
||
if strings.TrimSpace(req.DocID) != "" {
|
||
docIDParam = req.DocID
|
||
}
|
||
_, err := db.Exec(ctx, `
|
||
INSERT INTO kb_delete_audit (
|
||
id, api_key_id, permissions_users, dry_run, deleted_count, target_id, target_doc_id, collection, filter
|
||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||
`, generateID(), apiKeyID, permissions, req.DryRun, deleted, idParam, docIDParam, req.Collection, filterJSON)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to write delete audit: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func recordPermissionsAudit(ctx context.Context, db *pgxpool.Pool, apiKeyID string, permissions []string, r *http.Request, providedPermissions any, providedMetadata any) error {
|
||
if strings.TrimSpace(apiKeyID) == "" {
|
||
return errors.New("api_key_id is missing for audit")
|
||
}
|
||
if len(permissions) == 0 {
|
||
return errors.New("permissions.users is missing for audit")
|
||
}
|
||
if providedPermissions == nil {
|
||
return errors.New("provided_permissions is missing for audit")
|
||
}
|
||
permsJSON, err := json.Marshal(providedPermissions)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal provided permissions: %w", err)
|
||
}
|
||
if providedMetadata == nil {
|
||
return errors.New("provided_metadata is missing for audit")
|
||
}
|
||
metadataJSON, err := json.Marshal(providedMetadata)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal metadata: %w", err)
|
||
}
|
||
remoteAddr := r.RemoteAddr
|
||
userAgent := r.UserAgent()
|
||
endpoint := r.Method + " " + r.URL.Path
|
||
|
||
log.Printf("permissions audit: api_key_id=%s endpoint=%s remote=%s", apiKeyID, endpoint, remoteAddr)
|
||
|
||
_, err = db.Exec(ctx, `
|
||
INSERT INTO kb_permissions_audit (
|
||
id, api_key_id, permissions_users, requested_at, endpoint, remote_addr, user_agent, provided_permissions, provided_metadata
|
||
) VALUES ($1, $2, $3, now(), $4, $5, $6, $7::jsonb, $8::jsonb)
|
||
`, generateID(), apiKeyID, permissions, endpoint, remoteAddr, userAgent, string(permsJSON), string(metadataJSON))
|
||
if err != nil {
|
||
return fmt.Errorf("failed to write permissions audit: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func handleCollections(db *pgxpool.Pool) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if db == nil {
|
||
writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
|
||
return
|
||
}
|
||
if r.Method != http.MethodGet {
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
permissions, ok := r.Context().Value(ctxPermissionsUsers).([]string)
|
||
if !ok || len(permissions) == 0 {
|
||
writeError(w, http.StatusForbidden, "forbidden", "permissions.users not set")
|
||
return
|
||
}
|
||
permFilter := map[string]interface{}{
|
||
"contains": map[string]interface{}{
|
||
"metadata.permissions.users": permissions,
|
||
},
|
||
}
|
||
whereSQL, params, _, err := buildFilterSQL(permFilter, 1)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
query := fmt.Sprintf(`
|
||
SELECT collection, COUNT(*) AS cnt
|
||
FROM kb_doc_chunks
|
||
WHERE %s
|
||
GROUP BY collection
|
||
ORDER BY collection ASC
|
||
`, whereSQL)
|
||
rows, err := db.Query(r.Context(), query, params...)
|
||
if err != nil {
|
||
writeError(w, http.StatusInternalServerError, "collections_failed", fmt.Sprintf("failed to list collections: %v", err))
|
||
return
|
||
}
|
||
defer rows.Close()
|
||
items := make([]CollectionItem, 0)
|
||
for rows.Next() {
|
||
var item CollectionItem
|
||
if err := rows.Scan(&item.Name, &item.Count); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
|
||
return
|
||
}
|
||
items = append(items, item)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, CollectionsResponse{Items: items})
|
||
}
|
||
}
|
||
|
||
func handleAdminCollections(db *pgxpool.Pool) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if db == nil {
|
||
writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
|
||
return
|
||
}
|
||
if r.Method != http.MethodGet {
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
rows, err := db.Query(r.Context(), `
|
||
SELECT collection, COUNT(*) AS cnt
|
||
FROM kb_doc_chunks
|
||
GROUP BY collection
|
||
ORDER BY collection ASC
|
||
`)
|
||
if err != nil {
|
||
writeError(w, http.StatusInternalServerError, "collections_failed", fmt.Sprintf("failed to list collections: %v", err))
|
||
return
|
||
}
|
||
defer rows.Close()
|
||
items := make([]CollectionItem, 0)
|
||
for rows.Next() {
|
||
var item CollectionItem
|
||
if err := rows.Scan(&item.Name, &item.Count); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
|
||
return
|
||
}
|
||
items = append(items, item)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "collections_failed", "failed to read collections")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, CollectionsResponse{Items: items})
|
||
}
|
||
}
|
||
|
||
func handleAdminPermissionsAudit(db *pgxpool.Pool) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if db == nil {
|
||
writeError(w, http.StatusServiceUnavailable, "db_not_configured", "DATABASE_URL is not set")
|
||
return
|
||
}
|
||
if r.Method != http.MethodGet {
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
q := r.URL.Query()
|
||
since, err := parseTimeParam(q.Get("since"))
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "since must be RFC3339")
|
||
return
|
||
}
|
||
until, err := parseTimeParam(q.Get("until"))
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "until must be RFC3339")
|
||
return
|
||
}
|
||
limit, err := parseLimit(q.Get("limit"), 100, 1000)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
apiKeyID := strings.TrimSpace(q.Get("api_key_id"))
|
||
|
||
clauses := make([]string, 0, 4)
|
||
params := make([]interface{}, 0, 4)
|
||
idx := 1
|
||
if since != nil {
|
||
clauses = append(clauses, fmt.Sprintf("requested_at >= $%d", idx))
|
||
params = append(params, *since)
|
||
idx++
|
||
}
|
||
if until != nil {
|
||
clauses = append(clauses, fmt.Sprintf("requested_at <= $%d", idx))
|
||
params = append(params, *until)
|
||
idx++
|
||
}
|
||
if apiKeyID != "" {
|
||
clauses = append(clauses, fmt.Sprintf("api_key_id = $%d", idx))
|
||
params = append(params, apiKeyID)
|
||
idx++
|
||
}
|
||
whereSQL := "TRUE"
|
||
if len(clauses) > 0 {
|
||
whereSQL = strings.Join(clauses, " AND ")
|
||
}
|
||
|
||
query := fmt.Sprintf(`
|
||
SELECT id::text, api_key_id, permissions_users, requested_at, endpoint, remote_addr, user_agent, provided_permissions, provided_metadata
|
||
FROM kb_permissions_audit
|
||
WHERE %s
|
||
ORDER BY requested_at DESC
|
||
LIMIT $%d
|
||
`, whereSQL, idx)
|
||
params = append(params, limit)
|
||
|
||
rows, err := db.Query(r.Context(), query, params...)
|
||
if err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", fmt.Sprintf("failed to read audit logs: %v", err))
|
||
return
|
||
}
|
||
defer rows.Close()
|
||
|
||
items := make([]PermissionsAuditItem, 0)
|
||
for rows.Next() {
|
||
var item PermissionsAuditItem
|
||
var permsJSON []byte
|
||
var metaJSON []byte
|
||
if err := rows.Scan(&item.ID, &item.APIKeyID, &item.PermissionsUsers, &item.RequestedAt, &item.Endpoint, &item.RemoteAddr, &item.UserAgent, &permsJSON, &metaJSON); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", "failed to read audit logs")
|
||
return
|
||
}
|
||
if len(permsJSON) > 0 {
|
||
if err := json.Unmarshal(permsJSON, &item.ProvidedPermissions); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", "failed to parse audit logs")
|
||
return
|
||
}
|
||
}
|
||
if len(metaJSON) > 0 {
|
||
if err := json.Unmarshal(metaJSON, &item.ProvidedMetadata); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", "failed to parse audit logs")
|
||
return
|
||
}
|
||
}
|
||
items = append(items, item)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
writeError(w, http.StatusInternalServerError, "audit_failed", "failed to read audit logs")
|
||
return
|
||
}
|
||
|
||
writeJSON(w, http.StatusOK, PermissionsAuditResponse{Items: items})
|
||
}
|
||
}
|
||
|
||
func handleAdminApiKeys(store APIKeyStore) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
path := strings.TrimPrefix(r.URL.Path, "/admin/api-keys")
|
||
path = strings.TrimPrefix(path, "/")
|
||
if path == "" {
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
items, err := store.List(r.Context())
|
||
if err != nil {
|
||
writeError(w, http.StatusInternalServerError, "list_failed", "failed to list api keys")
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, AdminApiKeyListResponse{Items: items})
|
||
case http.MethodPost:
|
||
var req AdminApiKeyCreateRequest
|
||
if err := decodeJSON(r, &req); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
rawKey, key, err := store.Create(r.Context(), req.Label, req.PermissionsUsers)
|
||
if err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, AdminApiKeyCreateResponse{
|
||
APIKey: rawKey,
|
||
Key: key,
|
||
})
|
||
default:
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
}
|
||
return
|
||
}
|
||
|
||
parts := strings.Split(path, "/")
|
||
if len(parts) >= 2 && parts[1] == "revoke" {
|
||
if r.Method != http.MethodPost {
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
key, err := store.Revoke(r.Context(), parts[0])
|
||
if err != nil {
|
||
if errors.Is(err, ErrNotFound) {
|
||
writeError(w, http.StatusNotFound, "not_found", "api key not found")
|
||
return
|
||
}
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, key)
|
||
return
|
||
}
|
||
|
||
if len(parts) == 1 && r.Method == http.MethodPatch {
|
||
var req AdminApiKeyUpdateRequest
|
||
if err := decodeJSON(r, &req); err != nil {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
updated, err := store.Update(r.Context(), parts[0], req.Label, req.PermissionsUsers)
|
||
if err != nil {
|
||
if errors.Is(err, ErrNotFound) {
|
||
writeError(w, http.StatusNotFound, "not_found", "api key not found")
|
||
return
|
||
}
|
||
writeError(w, http.StatusBadRequest, "invalid_request", err.Error())
|
||
return
|
||
}
|
||
writeJSON(w, http.StatusOK, updated)
|
||
return
|
||
}
|
||
|
||
w.WriteHeader(http.StatusNotFound)
|
||
}
|
||
}
|
||
|
||
func handleAdminUI(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||
_, _ = w.Write([]byte(adminHTML))
|
||
}
|
||
|
||
// writeJSON marshals v and writes it as an application/json response with
// the given status and an explicit Content-Length.
//
// The value is marshaled before any status is written. If encoding fails, the
// client therefore receives a 500 with a fixed JSON error body instead of a
// truncated response. The encoding error is logged, because the fixed body
// cannot carry it.
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")
	payload, err := json.Marshal(v)
	if err != nil {
		log.Printf("writeJSON: encoding %T failed: %v", v, err)
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"failed to encode response","code":"encode_failed"}`))
		return
	}
	w.Header().Set("Content-Length", strconv.Itoa(len(payload)))
	w.WriteHeader(status)
	if _, err := w.Write(payload); err != nil {
		log.Printf("writeJSON failed: %v", err)
	}
}
|
||
|
||
func writeError(w http.ResponseWriter, status int, code, msg string) {
|
||
writeJSON(w, status, ErrorResponse{Error: msg, Code: code})
|
||
}
|
||
|
||
// decodeJSON decodes the request body into v.
//
// Unknown object fields are rejected, as is any further JSON value or garbage
// after the first value. Without that check a body like `{"a":1}{"a":2}` would
// be half-read and silently accepted. Trailing whitespace is allowed.
func decodeJSON(r *http.Request, v interface{}) error {
	dec := json.NewDecoder(r.Body)
	dec.DisallowUnknownFields()
	if err := dec.Decode(v); err != nil {
		return err
	}
	// More skips whitespace and reports whether another token follows.
	// A stray closing ']' or '}' is not reported by More.
	if dec.More() {
		return errors.New("request body must contain a single JSON value")
	}
	return nil
}
|
||
|
||
// adminHTML is the single-page admin console that handleAdminUI serves.
//
// The page's script calls /admin/api-keys (list, create, revoke, update) and
// /admin/collections, sending the admin key in the X-ADMIN-API-KEY header.
// When the page is opened on localhost, the key can also come from the
// admin_key query parameter.
//
// Server-supplied values (key IDs, labels, permission users, timestamps,
// collection names) are written into the tables with textContent, never
// innerHTML. Collection names come from stored documents, so rendering them
// as markup would let stored data run script in the admin's session. Key IDs
// placed in request paths go through encodeURIComponent.
const adminHTML = `<!doctype html>
<html lang="ja">
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <title>pgvecter API Admin</title>
  <style>
    body { font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, sans-serif; margin: 24px; }
    h1 { margin-bottom: 12px; }
    section { margin-bottom: 24px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
    th { background: #f5f5f5; }
    input, textarea, button { font-size: 14px; }
    textarea { width: 100%; height: 60px; }
    .row { display: flex; gap: 12px; }
    .row > div { flex: 1; }
    .muted { color: #666; font-size: 12px; }
  </style>
</head>
<body>
  <h1>pgvecter API Admin</h1>
  <p class="muted">X-ADMIN-API-KEY を使って管理APIを呼び出します(admin_key クエリは localhost のみ)。</p>
  <div id="status" class="muted"></div>

  <section>
    <h2>接続</h2>
    <div class="row">
      <div>
        <label>API Base URL</label><br />
        <input id="baseUrl" value="http://localhost:8080" style="width:100%" />
      </div>
      <div>
        <label>Admin API Key</label><br />
        <input id="adminKey" type="password" style="width:100%" />
      </div>
    </div>
    <button onclick="loadKeys()">一覧を取得</button>
  </section>

  <section>
    <h2>発行</h2>
    <div class="row">
      <div>
        <label>Label</label><br />
        <input id="newLabel" style="width:100%" />
      </div>
      <div>
        <label>permissions.users(改行区切り)</label><br />
        <textarea id="newUsers"></textarea>
      </div>
    </div>
    <button onclick="createKey()">発行</button>
    <pre id="issuedKey"></pre>
    <button onclick="copyIssuedKey()">コピー</button>
  </section>

  <section>
    <h2>APIキー一覧</h2>
    <table id="keysTable">
      <thead>
        <tr>
          <th>ID</th>
          <th>Label</th>
          <th>Users</th>
          <th>Status</th>
          <th>Created</th>
          <th>Last Used</th>
          <th>Actions</th>
        </tr>
      </thead>
      <tbody></tbody>
    </table>
  </section>

  <section>
    <h2>コレクション一覧</h2>
    <table id="collectionsTable">
      <thead>
        <tr>
          <th>Name</th>
          <th>Count</th>
        </tr>
      </thead>
      <tbody></tbody>
    </table>
  </section>

  <section id="editSection">
    <h2>編集</h2>
    <div class="row">
      <div>
        <label>Key ID</label><br />
        <input id="editId" style="width:100%" />
      </div>
      <div>
        <label>Label</label><br />
        <input id="editLabel" style="width:100%" />
      </div>
      <div>
        <label>permissions.users(改行区切り)</label><br />
        <textarea id="editUsers"></textarea>
      </div>
    </div>
    <button onclick="updateKey()">更新</button>
  </section>

  <script>
    function getBaseUrl() { return document.getElementById('baseUrl').value.trim(); }

    function getAdminKey() {
      const input = document.getElementById('adminKey').value.trim();
      if (input) { return input; }
      const host = window.location.hostname;
      if (host === 'localhost' || host === '127.0.0.1' || host === '::1' || host === '[::1]') {
        const params = new URLSearchParams(window.location.search);
        return params.get('admin_key') || '';
      }
      return '';
    }

    function headers() {
      return { 'Content-Type': 'application/json', 'X-ADMIN-API-KEY': getAdminKey() };
    }

    function parseUsers(text) {
      return text.split('\n').map(s => s.trim()).filter(Boolean);
    }

    function appendCell(tr, value) {
      const td = document.createElement('td');
      td.textContent = value == null ? '' : String(value);
      tr.appendChild(td);
      return td;
    }

    function makeButton(className, label, onClick) {
      const btn = document.createElement('button');
      btn.className = className;
      btn.textContent = label;
      btn.addEventListener('click', onClick);
      return btn;
    }

    async function loadKeys() {
      const res = await fetch(getBaseUrl() + '/admin/api-keys', { headers: headers() });
      if (!res.ok) { setStatus('一覧取得に失敗しました', 'error'); return; }
      const data = await res.json();
      const tbody = document.querySelector('#keysTable tbody');
      tbody.textContent = '';
      for (const item of data.items || []) {
        const tr = document.createElement('tr');
        appendCell(tr, item.id);
        appendCell(tr, item.label);
        appendCell(tr, (item.permissions_users || []).join(', '));
        appendCell(tr, item.status);
        appendCell(tr, item.created_at);
        appendCell(tr, item.last_used_at || '');
        const actions = appendCell(tr, '');
        actions.appendChild(makeButton('revoke', '失効', () => revokeKey(item.id)));
        actions.appendChild(document.createTextNode(' '));
        actions.appendChild(makeButton('edit', '編集', () => fillEditForm(item)));
        tbody.appendChild(tr);
      }
    }

    async function loadCollections() {
      const res = await fetch(getBaseUrl() + '/admin/collections', { headers: headers() });
      if (!res.ok) { setStatus('コレクション取得に失敗しました', 'error'); return; }
      const data = await res.json();
      const tbody = document.querySelector('#collectionsTable tbody');
      tbody.textContent = '';
      for (const item of data.items || []) {
        const tr = document.createElement('tr');
        appendCell(tr, item.name);
        appendCell(tr, item.count);
        tbody.appendChild(tr);
      }
    }

    async function createKey() {
      const label = document.getElementById('newLabel').value.trim();
      const users = parseUsers(document.getElementById('newUsers').value);
      const res = await fetch(getBaseUrl() + '/admin/api-keys', {
        method: 'POST',
        headers: headers(),
        body: JSON.stringify({ label: label, permissions_users: users })
      });
      if (!res.ok) { setStatus('発行に失敗しました', 'error'); return; }
      const data = await res.json();
      document.getElementById('issuedKey').textContent = 'API Key: ' + data.api_key;
      setStatus('発行しました', 'ok');
      await loadKeys();
      await loadCollections();
    }

    async function revokeKey(id) {
      const res = await fetch(getBaseUrl() + '/admin/api-keys/' + encodeURIComponent(id) + '/revoke', {
        method: 'POST',
        headers: headers()
      });
      if (!res.ok) { setStatus('失効に失敗しました', 'error'); return; }
      setStatus('失効しました', 'ok');
      await loadKeys();
      await loadCollections();
    }

    function fillEditForm(item) {
      document.getElementById('editId').value = item.id;
      document.getElementById('editLabel').value = item.label || '';
      document.getElementById('editUsers').value = (item.permissions_users || []).join('\n');
      const section = document.getElementById('editSection');
      if (section && section.scrollIntoView) {
        section.scrollIntoView({ behavior: 'smooth', block: 'start' });
      }
    }

    async function updateKey() {
      const id = document.getElementById('editId').value.trim();
      const label = document.getElementById('editLabel').value.trim();
      const users = parseUsers(document.getElementById('editUsers').value);
      if (!id) { setStatus('Key IDが必要です', 'error'); return; }
      const res = await fetch(getBaseUrl() + '/admin/api-keys/' + encodeURIComponent(id), {
        method: 'PATCH',
        headers: headers(),
        body: JSON.stringify({ label: label, permissions_users: users })
      });
      if (!res.ok) { setStatus('更新に失敗しました', 'error'); return; }
      setStatus('更新しました', 'ok');
      await loadKeys();
      await loadCollections();
    }

    function setStatus(msg, type) {
      const el = document.getElementById('status');
      if (!el) { return; }
      el.textContent = msg;
      el.style.color = type === 'error' ? '#b00020' : '#0b6e4f';
      clearTimeout(window.__statusTimer);
      window.__statusTimer = setTimeout(() => { el.textContent = ''; }, 3000);
    }

    async function copyIssuedKey() {
      const text = document.getElementById('issuedKey').textContent.replace('API Key: ', '').trim();
      if (!text) { setStatus('コピー対象がありません', 'error'); return; }
      try {
        await navigator.clipboard.writeText(text);
        setStatus('APIキーをコピーしました', 'ok');
      } catch (e) {
        setStatus('コピーに失敗しました', 'error');
      }
    }

    window.addEventListener('DOMContentLoaded', () => {
      loadKeys();
      loadCollections();
    });
  </script>
</body>
</html>`
|
||
|
||
func withAdminUIAuth(adminKey string, next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if adminKey == "" {
|
||
writeError(w, http.StatusServiceUnavailable, "admin_key_not_configured", "ADMIN_API_KEY is not set")
|
||
return
|
||
}
|
||
key := r.Header.Get("X-ADMIN-API-KEY")
|
||
if key == "" {
|
||
if isLocalhostRequest(r) {
|
||
key = r.URL.Query().Get("admin_key")
|
||
} else if r.URL.Query().Get("admin_key") != "" {
|
||
writeError(w, http.StatusBadRequest, "invalid_request", "admin_key query parameter is only allowed from localhost")
|
||
return
|
||
}
|
||
}
|
||
if key == "" || key != adminKey {
|
||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid admin api key")
|
||
return
|
||
}
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
func buildFilterSQL(filter map[string]interface{}, startIdx int) (string, []interface{}, int, error) {
|
||
if filter == nil || len(filter) == 0 {
|
||
return "TRUE", nil, startIdx, nil
|
||
}
|
||
if v, ok := filter["and"]; ok {
|
||
items, ok := v.([]interface{})
|
||
if !ok || len(items) == 0 {
|
||
return "", nil, startIdx, errors.New("and must be a non-empty array")
|
||
}
|
||
var parts []string
|
||
var params []interface{}
|
||
idx := startIdx
|
||
for _, item := range items {
|
||
m, ok := item.(map[string]interface{})
|
||
if !ok {
|
||
return "", nil, startIdx, errors.New("and items must be objects")
|
||
}
|
||
sql, p, next, err := buildFilterSQL(m, idx)
|
||
if err != nil {
|
||
return "", nil, startIdx, err
|
||
}
|
||
parts = append(parts, "("+sql+")")
|
||
params = append(params, p...)
|
||
idx = next
|
||
}
|
||
return strings.Join(parts, " AND "), params, idx, nil
|
||
}
|
||
if v, ok := filter["or"]; ok {
|
||
items, ok := v.([]interface{})
|
||
if !ok || len(items) == 0 {
|
||
return "", nil, startIdx, errors.New("or must be a non-empty array")
|
||
}
|
||
var parts []string
|
||
var params []interface{}
|
||
idx := startIdx
|
||
for _, item := range items {
|
||
m, ok := item.(map[string]interface{})
|
||
if !ok {
|
||
return "", nil, startIdx, errors.New("or items must be objects")
|
||
}
|
||
sql, p, next, err := buildFilterSQL(m, idx)
|
||
if err != nil {
|
||
return "", nil, startIdx, err
|
||
}
|
||
parts = append(parts, "("+sql+")")
|
||
params = append(params, p...)
|
||
idx = next
|
||
}
|
||
return strings.Join(parts, " OR "), params, idx, nil
|
||
}
|
||
|
||
if v, ok := filter["eq"]; ok {
|
||
return buildSimpleOp("eq", v, startIdx)
|
||
}
|
||
if v, ok := filter["in"]; ok {
|
||
return buildSimpleOp("in", v, startIdx)
|
||
}
|
||
if v, ok := filter["contains"]; ok {
|
||
return buildSimpleOp("contains", v, startIdx)
|
||
}
|
||
if v, ok := filter["exists"]; ok {
|
||
return buildSimpleOp("exists", v, startIdx)
|
||
}
|
||
if v, ok := filter["lt"]; ok {
|
||
return buildSimpleOp("lt", v, startIdx)
|
||
}
|
||
if v, ok := filter["lte"]; ok {
|
||
return buildSimpleOp("lte", v, startIdx)
|
||
}
|
||
if v, ok := filter["gt"]; ok {
|
||
return buildSimpleOp("gt", v, startIdx)
|
||
}
|
||
if v, ok := filter["gte"]; ok {
|
||
return buildSimpleOp("gte", v, startIdx)
|
||
}
|
||
return "", nil, startIdx, errors.New("unsupported filter operator")
|
||
}
|
||
|
||
func buildSimpleOp(op string, v interface{}, startIdx int) (string, []interface{}, int, error) {
|
||
obj, ok := v.(map[string]interface{})
|
||
if !ok || len(obj) != 1 {
|
||
return "", nil, startIdx, fmt.Errorf("%s must be an object with a single field", op)
|
||
}
|
||
var path string
|
||
var value interface{}
|
||
for k, val := range obj {
|
||
path = k
|
||
value = val
|
||
}
|
||
keys, err := parseMetadataPath(path)
|
||
if err != nil {
|
||
return "", nil, startIdx, err
|
||
}
|
||
pathLiteral := strings.Join(keys, ",")
|
||
|
||
switch op {
|
||
case "eq":
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' = $%d", pathLiteral, startIdx)
|
||
return sql, []interface{}{fmt.Sprint(value)}, startIdx + 1, nil
|
||
case "in":
|
||
list, ok := value.([]interface{})
|
||
if !ok || len(list) == 0 {
|
||
return "", nil, startIdx, errors.New("in must be a non-empty array")
|
||
}
|
||
values := make([]string, 0, len(list))
|
||
for _, item := range list {
|
||
values = append(values, fmt.Sprint(item))
|
||
}
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' = ANY($%d)", pathLiteral, startIdx)
|
||
return sql, []interface{}{values}, startIdx + 1, nil
|
||
case "contains":
|
||
switch val := value.(type) {
|
||
case string:
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' ILIKE '%%' || $%d || '%%'", pathLiteral, startIdx)
|
||
return sql, []interface{}{val}, startIdx + 1, nil
|
||
case []interface{}:
|
||
if len(val) == 0 {
|
||
return "", nil, startIdx, errors.New("contains array must be non-empty")
|
||
}
|
||
raw, err := json.Marshal(val)
|
||
if err != nil {
|
||
return "", nil, startIdx, err
|
||
}
|
||
sql := fmt.Sprintf("metadata #> '{%s}' @> $%d::jsonb", pathLiteral, startIdx)
|
||
return sql, []interface{}{string(raw)}, startIdx + 1, nil
|
||
case []string:
|
||
if len(val) == 0 {
|
||
return "", nil, startIdx, errors.New("contains array must be non-empty")
|
||
}
|
||
raw, err := json.Marshal(val)
|
||
if err != nil {
|
||
return "", nil, startIdx, err
|
||
}
|
||
sql := fmt.Sprintf("metadata #> '{%s}' @> $%d::jsonb", pathLiteral, startIdx)
|
||
return sql, []interface{}{string(raw)}, startIdx + 1, nil
|
||
default:
|
||
return "", nil, startIdx, errors.New("contains must be string or array")
|
||
}
|
||
case "exists":
|
||
if len(keys) == 1 {
|
||
sql := fmt.Sprintf("metadata ? '%s'", keys[0])
|
||
return sql, nil, startIdx, nil
|
||
}
|
||
parent := strings.Join(keys[:len(keys)-1], ",")
|
||
leaf := keys[len(keys)-1]
|
||
sql := fmt.Sprintf("(metadata #> '{%s}') ? '%s'", parent, leaf)
|
||
return sql, nil, startIdx, nil
|
||
case "lt":
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' < $%d", pathLiteral, startIdx)
|
||
return sql, []interface{}{fmt.Sprint(value)}, startIdx + 1, nil
|
||
case "lte":
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' <= $%d", pathLiteral, startIdx)
|
||
return sql, []interface{}{fmt.Sprint(value)}, startIdx + 1, nil
|
||
case "gt":
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' > $%d", pathLiteral, startIdx)
|
||
return sql, []interface{}{fmt.Sprint(value)}, startIdx + 1, nil
|
||
case "gte":
|
||
sql := fmt.Sprintf("metadata #>> '{%s}' >= $%d", pathLiteral, startIdx)
|
||
return sql, []interface{}{fmt.Sprint(value)}, startIdx + 1, nil
|
||
default:
|
||
return "", nil, startIdx, errors.New("unsupported operator")
|
||
}
|
||
}
|
||
|
||
// parseMetadataPath splits a dotted filter path such as
// "metadata.permissions.users" into its key segments ("permissions",
// "users").
//
// The path must start with "metadata.". Every segment must be non-empty and
// use only ASCII letters, digits and underscores, because callers splice the
// segments directly into SQL path literals.
func parseMetadataPath(path string) ([]string, error) {
	const prefix = "metadata."
	if !strings.HasPrefix(path, prefix) {
		return nil, errors.New("filter path must start with metadata.")
	}
	segments := strings.Split(strings.TrimPrefix(path, prefix), ".")
	if len(segments) == 0 {
		return nil, errors.New("filter path is invalid")
	}
	for _, seg := range segments {
		if seg == "" {
			return nil, errors.New("filter path contains invalid characters")
		}
		for _, c := range seg {
			isLetter := ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')
			isDigit := '0' <= c && c <= '9'
			if !isLetter && !isDigit && c != '_' {
				return nil, errors.New("filter path contains invalid characters")
			}
		}
	}
	return segments, nil
}
|
||
|
||
// isSafeIdent reports whether s consists only of ASCII letters, digits and
// underscores. The empty string is accepted; parseMetadataPath rejects empty
// segments separately.
func isSafeIdent(s string) bool {
	for _, c := range s {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9', c == '_':
			// allowed character
		default:
			return false
		}
	}
	return true
}
|
||
|
||
// vectorLiteral formats vec as a pgvector text literal, e.g. "[0.5,-1,1e-09]".
//
// Each component is rounded to float32 and written with the shortest decimal
// that parses back to the same float32. pgvector's vector type stores
// single-precision elements (per the pgvector docs). A fixed "%.8f" layout
// would instead flush small components such as 1e-9 to zero and keep only a
// few significant digits for values below 1e-3.
func vectorLiteral(vec []float64) string {
	var b strings.Builder
	b.Grow(len(vec)*12 + 2)
	b.WriteByte('[')
	for i, v := range vec {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.FormatFloat(v, 'g', -1, 32))
	}
	b.WriteByte(']')
	return b.String()
}
|
||
|
||
// distanceToSimilarity maps a pgvector cosine distance (nominally in [0, 2])
// onto a similarity score in [0, 1] via 1 - distance/2, clamping values that
// fall outside that range. A NaN distance is passed through unchanged.
func distanceToSimilarity(distance float64) float64 {
	similarity := 1 - distance/2.0
	switch {
	case similarity < 0:
		return 0
	case similarity > 1:
		return 1
	default:
		return similarity
	}
}
|
||
|
||
type (
	// embeddingRequest is the JSON body that createOpenAIEmbedding posts to
	// the OpenAI embeddings endpoint.
	embeddingRequest struct {
		Input          string `json:"input"`
		Model          string `json:"model"`
		EncodingFormat string `json:"encoding_format,omitempty"`
	}

	// embeddingResponse decodes only the embedding vectors from the OpenAI
	// embeddings reply; other fields are ignored.
	embeddingResponse struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
		} `json:"data"`
	}
)
|
||
|
||
func createEmbedding(ctx context.Context, text string) ([]float64, error) {
|
||
return createEmbeddingWithDim(ctx, text, embeddingDim())
|
||
}
|
||
|
||
func createEmbeddingWithDim(ctx context.Context, text string, dim int) ([]float64, error) {
|
||
provider := strings.ToLower(strings.TrimSpace(os.Getenv("EMBEDDING_PROVIDER")))
|
||
llamaURL := strings.TrimSpace(os.Getenv("LLAMA_CPP_URL"))
|
||
if provider == "" && llamaURL != "" {
|
||
provider = "llamacpp"
|
||
}
|
||
if provider == "" {
|
||
provider = "openai"
|
||
}
|
||
|
||
switch provider {
|
||
case "llamacpp":
|
||
return createLlamaCppEmbedding(ctx, text, dim)
|
||
case "openai":
|
||
return createOpenAIEmbedding(ctx, text)
|
||
default:
|
||
return nil, fmt.Errorf("unsupported EMBEDDING_PROVIDER: %s", provider)
|
||
}
|
||
}
|
||
|
||
// createLlamaCppEmbedding requests an embedding for text from a llama.cpp
// server's /embedding endpoint. The server is at LLAMA_CPP_URL, defaulting to
// http://127.0.0.1:8092. The result must have exactly dim components.
//
// Accepted response shapes:
//   - an array of results whose "embedding" is nested ([[...], ...]) or flat ([...]);
//   - a single object {"embedding": [...]}, as returned by older servers.
//
// Only the first result (and, for a nested embedding, its first row) is used.
// Non-2xx statuses, undecodable bodies, empty embeddings and dimension
// mismatches are returned as errors.
func createLlamaCppEmbedding(ctx context.Context, text string, dim int) ([]float64, error) {
	baseURL := strings.TrimSpace(os.Getenv("LLAMA_CPP_URL"))
	if baseURL == "" {
		baseURL = "http://127.0.0.1:8092"
	}
	endpoint := strings.TrimRight(baseURL, "/") + "/embedding"

	payload, err := json.Marshal(map[string]string{
		"content": text,
	})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(payload)))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("llama.cpp embedding request failed with status %d", resp.StatusCode)
	}

	var raw json.RawMessage
	if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
		return nil, err
	}

	// firstVector extracts the first vector from an "embedding" field that is
	// either nested ([[...], ...]) or flat ([...]).
	firstVector := func(field json.RawMessage) []float64 {
		var nested [][]float64
		if json.Unmarshal(field, &nested) == nil {
			if len(nested) > 0 {
				return nested[0]
			}
			return nil
		}
		var flat []float64
		if json.Unmarshal(field, &flat) == nil {
			return flat
		}
		return nil
	}

	type result struct {
		Embedding json.RawMessage `json:"embedding"`
	}
	var emb []float64
	body := strings.TrimSpace(string(raw))
	switch {
	case strings.HasPrefix(body, "["):
		var results []result
		if err := json.Unmarshal(raw, &results); err != nil {
			return nil, err
		}
		if len(results) > 0 {
			emb = firstVector(results[0].Embedding)
		}
	case strings.HasPrefix(body, "{"):
		var single result
		if err := json.Unmarshal(raw, &single); err != nil {
			return nil, err
		}
		emb = firstVector(single.Embedding)
	}
	if len(emb) == 0 {
		return nil, errors.New("llama.cpp embedding response is empty")
	}
	if len(emb) != dim {
		return nil, fmt.Errorf("embedding dimension mismatch: got %d, expected %d", len(emb), dim)
	}
	return emb, nil
}
|
||
|
||
func createOpenAIEmbedding(ctx context.Context, text string) ([]float64, error) {
|
||
apiKey := strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
|
||
if apiKey == "" {
|
||
return nil, errors.New("OPENAI_API_KEY is not set")
|
||
}
|
||
model := strings.TrimSpace(os.Getenv("EMBEDDING_MODEL"))
|
||
if model == "" {
|
||
model = "text-embedding-3-small"
|
||
}
|
||
|
||
reqBody := embeddingRequest{
|
||
Input: text,
|
||
Model: model,
|
||
EncodingFormat: "float",
|
||
}
|
||
payload, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.openai.com/v1/embeddings", strings.NewReader(string(payload)))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
client := &http.Client{Timeout: 15 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
return nil, fmt.Errorf("embedding request failed with status %d", resp.StatusCode)
|
||
}
|
||
|
||
var out embeddingResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||
return nil, err
|
||
}
|
||
if len(out.Data) == 0 || len(out.Data[0].Embedding) == 0 {
|
||
return nil, errors.New("embedding response is empty")
|
||
}
|
||
return out.Data[0].Embedding, nil
|
||
}
|
||
|
||
// embeddingDim returns the embedding vector dimension configured through the
// EMBEDDING_DIM environment variable (surrounding whitespace ignored).
// A missing, non-numeric, zero or negative value falls back to 1024.
func embeddingDim() int {
	const fallback = 1024
	configured, err := strconv.Atoi(strings.TrimSpace(os.Getenv("EMBEDDING_DIM")))
	if err == nil && configured > 0 {
		return configured
	}
	return fallback
}
|
||
|
||
// parseTimeParam parses an optional RFC 3339 timestamp taken from a request
// parameter.
//
// An empty or all-whitespace raw means "not provided" and yields (nil, nil).
// Otherwise the value, with surrounding whitespace removed, must parse with
// time.RFC3339 or the parse error is returned.
//
// An unencoded '+' in a URL query string is decoded as a space, so a client
// sending "2024-01-02T03:04:05+09:00" unescaped arrives here as
// "2024-01-02T03:04:05 09:00". The time.RFC3339 layout never accepts a space.
// Any interior space is therefore turned back into '+' before parsing.
func parseTimeParam(raw string) (*time.Time, error) {
	s := strings.TrimSpace(raw)
	if s == "" {
		return nil, nil
	}
	s = strings.ReplaceAll(s, " ", "+")
	t, err := time.Parse(time.RFC3339, s)
	if err != nil {
		return nil, err
	}
	return &t, nil
}
|
||
|
||
// parseLimit parses an optional positive integer limit from a request
// parameter.
//
// An empty or all-whitespace raw yields def; def is returned as-is and is not
// compared against max. Otherwise raw, ignoring surrounding whitespace, must
// be a positive integer, or the error "limit must be a positive integer" is
// returned. Values greater than max are clamped to max.
func parseLimit(raw string, def, max int) (int, error) {
	// Trim once and use the trimmed value for both the emptiness check and the
	// parse, so " 5 " is accepted rather than rejected.
	s := strings.TrimSpace(raw)
	if s == "" {
		return def, nil
	}
	n, err := strconv.Atoi(s)
	if err != nil || n <= 0 {
		return 0, errors.New("limit must be a positive integer")
	}
	if n > max {
		return max, nil
	}
	return n, nil
}
|
||
|
||
// isLocalhostRequest reports whether r came from a loopback address
// (e.g. 127.0.0.0/8 or ::1), judged only by r.RemoteAddr. Forwarding headers
// are not consulted. A RemoteAddr without a port is parsed as a bare IP. A
// value that is not a literal IP address (hostnames included) is treated as
// non-local.
func isLocalhostRequest(r *http.Request) bool {
	candidate := r.RemoteAddr
	if hostOnly, _, splitErr := net.SplitHostPort(candidate); splitErr == nil {
		candidate = hostOnly
	}
	parsed := net.ParseIP(candidate)
	return parsed != nil && parsed.IsLoopback()
}
|
||
|
||
// hasMetadataPermissions reports whether meta has a top-level "permissions"
// key, whatever its value (a present key with a nil value still counts).
// A nil or empty map has no keys and yields false.
func hasMetadataPermissions(meta map[string]interface{}) bool {
	if len(meta) == 0 {
		return false
	}
	_, present := meta["permissions"]
	return present
}
|
||
|
||
func filterHasMetadataPermissions(filter map[string]interface{}) bool {
|
||
return valueHasMetadataPermissions(filter)
|
||
}
|
||
|
||
// valueHasMetadataPermissions walks a JSON-like value made of nested
// map[string]interface{} and []interface{} and reports whether any map key,
// at any depth, starts with "metadata.permissions". Only keys are inspected;
// string values never match. Values of any other type (scalars, nil, or
// typed slices such as []string) are leaves and yield false.
func valueHasMetadataPermissions(value interface{}) bool {
	if object, isObject := value.(map[string]interface{}); isObject {
		for key, nested := range object {
			if strings.HasPrefix(key, "metadata.permissions") || valueHasMetadataPermissions(nested) {
				return true
			}
		}
		return false
	}
	if list, isList := value.([]interface{}); isList {
		for _, element := range list {
			if valueHasMetadataPermissions(element) {
				return true
			}
		}
	}
	return false
}
|
||
func fetchEmbeddingDim(ctx context.Context, db *pgxpool.Pool) (int, error) {
|
||
var typeStr string
|
||
err := db.QueryRow(ctx, `
|
||
SELECT format_type(a.atttypid, a.atttypmod)
|
||
FROM pg_attribute a
|
||
JOIN pg_class c ON a.attrelid = c.oid
|
||
JOIN pg_namespace n ON n.oid = c.relnamespace
|
||
WHERE n.nspname = 'public'
|
||
AND c.relname = 'kb_doc_chunks'
|
||
AND a.attname = 'embedding'
|
||
AND a.attnum > 0
|
||
AND NOT a.attisdropped
|
||
`).Scan(&typeStr)
|
||
if err != nil {
|
||
if errors.Is(err, pgx.ErrNoRows) {
|
||
return 0, errors.New("kb_doc_chunks.embedding not found")
|
||
}
|
||
return 0, err
|
||
}
|
||
dim, err := parseVectorTypeDim(typeStr)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return dim, nil
|
||
}
|
||
|
||
func verifyEmbeddingDim(db *pgxpool.Pool, expected int) error {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
defer cancel()
|
||
actual, err := fetchEmbeddingDim(ctx, db)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if actual != expected {
|
||
return fmt.Errorf("embedding dimension mismatch: EMBEDDING_DIM=%d db=%d", expected, actual)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// parseVectorTypeDim extracts N from a type name of the form "vector(N)",
// ignoring surrounding whitespace. N must parse as a positive integer. Any
// other shape, including an unconstrained "vector", yields
// "invalid embedding type: <trimmed input>".
func parseVectorTypeDim(typeStr string) (int, error) {
	const prefix, suffix = "vector(", ")"
	s := strings.TrimSpace(typeStr)
	if strings.HasPrefix(s, prefix) && strings.HasSuffix(s, suffix) {
		digits := s[len(prefix) : len(s)-len(suffix)]
		if dim, err := strconv.Atoi(digits); err == nil && dim > 0 {
			return dim, nil
		}
	}
	return 0, fmt.Errorf("invalid embedding type: %s", s)
}
|
||
|
||
func logConnectionInfo(db *pgxpool.Pool) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
defer cancel()
|
||
var user string
|
||
var dbname string
|
||
if err := db.QueryRow(ctx, "SELECT current_user, current_database()").Scan(&user, &dbname); err != nil {
|
||
log.Printf("kb.search db user check failed: %v", err)
|
||
return
|
||
}
|
||
log.Printf("kb.search db user=%s db=%s", user, dbname)
|
||
}
|
||
|
||
// splitIntoChunks splits text into chunks of at most size runes.
//
// CRLF line endings are normalized to "\n" and the text is split into
// paragraphs on blank lines (see splitParagraphs). Paragraphs are packed
// greedily into a chunk, joined by "\n\n", until the next one would push the
// chunk past size. The separator counts toward the limit.
//
// When a chunk is emitted, its last overlap runes are carried into the next
// chunk as context. If that tail plus the separator and the next paragraph
// would exceed size, the tail is shortened (or dropped) so the limit still
// holds. A paragraph longer than size is emitted as rune windows from
// splitByRunes, with no overlap carried in from the previous chunk.
//
// A non-positive size returns the input unchanged as a single chunk. Input
// without any non-blank line returns a single empty chunk.
func splitIntoChunks(text string, size int, overlap int) []string {
	if size <= 0 {
		return []string{text}
	}
	// sepLen is the rune length of the "\n\n" paragraph separator.
	const sepLen = 2

	text = strings.ReplaceAll(text, "\r\n", "\n")
	paras := splitParagraphs(text)

	var chunks []string
	var buf []rune

	// flush emits buf as a chunk. It then keeps buf's last overlap runes, when
	// buf is longer than that, as the start of the next chunk.
	flush := func() {
		if len(buf) == 0 {
			return
		}
		chunks = append(chunks, string(buf))
		if overlap > 0 && len(buf) > overlap {
			buf = append([]rune{}, buf[len(buf)-overlap:]...)
		} else {
			buf = buf[:0]
		}
	}

	for _, p := range paras {
		rp := []rune(p)
		if len(rp) > size {
			flush()
			chunks = append(chunks, splitByRunes(p, size, overlap)...)
			buf = buf[:0]
			continue
		}
		if len(buf) > 0 && len(buf)+sepLen+len(rp) > size {
			flush()
			// flush may have carried an overlap tail. Keep only as much of it
			// as still leaves room for the separator and this paragraph, so
			// the next chunk cannot exceed size.
			room := size - sepLen - len(rp)
			if room < 0 {
				room = 0
			}
			if len(buf) > room {
				buf = buf[len(buf)-room:]
			}
		}
		if len(buf) > 0 {
			buf = append(buf, '\n', '\n')
		}
		buf = append(buf, rp...)
	}
	flush()

	if len(chunks) == 0 {
		return []string{""}
	}
	return chunks
}

// splitParagraphs splits text on blank lines (empty or whitespace-only
// lines). Lines within a paragraph are re-joined with "\n" and keep their own
// whitespace. Runs of blank lines produce no empty paragraphs, so all-blank
// input yields nil.
func splitParagraphs(text string) []string {
	var paras []string
	var cur []string
	for _, line := range strings.Split(text, "\n") {
		if strings.TrimSpace(line) != "" {
			cur = append(cur, line)
			continue
		}
		if len(cur) > 0 {
			paras = append(paras, strings.Join(cur, "\n"))
			cur = cur[:0]
		}
	}
	if len(cur) > 0 {
		paras = append(paras, strings.Join(cur, "\n"))
	}
	return paras
}

// splitByRunes cuts text into windows of at most size runes. Each window
// starts overlap runes before the end of the previous one.
//
// Text that already fits, or a non-positive size, is returned as a single
// element; the size guard prevents an infinite loop. A negative overlap is
// treated as 0. An overlap >= size is reduced to size/4, so every step
// advances by at least one rune.
func splitByRunes(text string, size int, overlap int) []string {
	r := []rune(text)
	if size <= 0 || len(r) <= size {
		return []string{text}
	}
	if overlap < 0 {
		overlap = 0
	}
	if overlap >= size {
		overlap = size / 4
	}

	var out []string
	for i := 0; i < len(r); {
		end := i + size
		if end > len(r) {
			end = len(r)
		}
		out = append(out, string(r[i:end]))
		if end == len(r) {
			break
		}
		// end == i+size here and overlap < size, so i strictly increases.
		i = end - overlap
	}
	return out
}
|
||
|
||
// validateCollection checks a collection name. The name must be 1-50 bytes
// long ("invalid length") and consist only of lowercase ASCII letters,
// digits and underscores ("invalid character"). Length is checked first.
func validateCollection(name string) error {
	if n := len(name); n == 0 || n > 50 {
		return errors.New("invalid length")
	}
	for i := 0; i < len(name); i++ {
		c := name[i]
		isLower := 'a' <= c && c <= 'z'
		isDigit := '0' <= c && c <= '9'
		if !isLower && !isDigit && c != '_' {
			return errors.New("invalid character")
		}
	}
	return nil
}
|
||
|
||
func initStore() (APIKeyStore, *pgxpool.Pool, func(), error) {
|
||
dbURL := os.Getenv("DATABASE_URL")
|
||
if strings.TrimSpace(dbURL) == "" {
|
||
log.Println("warning: DATABASE_URL not set; using in-memory api key store")
|
||
return NewMemoryKeyStore(), nil, func() {}, nil
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
pool, err := pgxpool.New(ctx, dbURL)
|
||
if err != nil {
|
||
return nil, nil, func() {}, err
|
||
}
|
||
if err := pool.Ping(ctx); err != nil {
|
||
pool.Close()
|
||
return nil, nil, func() {}, err
|
||
}
|
||
store := NewPostgresKeyStore(pool)
|
||
if err := store.EnsureSchema(ctx); err != nil {
|
||
pool.Close()
|
||
return nil, nil, func() {}, err
|
||
}
|
||
return store, pool, pool.Close, nil
|
||
}
|
||
|
||
func main() {
|
||
adminKey := os.Getenv("ADMIN_API_KEY")
|
||
if adminKey == "" {
|
||
log.Println("warning: ADMIN_API_KEY not set; admin endpoints will return 503")
|
||
}
|
||
|
||
store, db, closeStore, err := initStore()
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
defer closeStore()
|
||
if db != nil {
|
||
if err := verifyEmbeddingDim(db, embeddingDim()); err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}
|
||
mux := http.NewServeMux()
|
||
|
||
mux.HandleFunc("/health", handleHealth)
|
||
|
||
mux.Handle("/db/query", withAPIKeyAuth(store, http.HandlerFunc(handleNotImplemented)))
|
||
mux.Handle("/kb/upsert", withAPIKeyAuth(store, handleKbUpsert(db)))
|
||
mux.Handle("/kb/search", withAPIKeyAuth(store, handleKbSearch(db)))
|
||
mux.Handle("/kb/delete", withAPIKeyAuth(store, handleKbDelete(db)))
|
||
mux.Handle("/collections", withAPIKeyAuth(store, handleCollections(db)))
|
||
|
||
adminHandler := withAdminAuth(adminKey, http.HandlerFunc(handleAdminApiKeys(store)))
|
||
mux.Handle("/admin/api-keys", adminHandler)
|
||
mux.Handle("/admin/api-keys/", adminHandler)
|
||
mux.Handle("/admin/collections", withAdminAuth(adminKey, handleAdminCollections(db)))
|
||
mux.Handle("/admin/audit/permissions", withAdminAuth(adminKey, handleAdminPermissionsAudit(db)))
|
||
mux.Handle("/admin", withAdminUIAuth(adminKey, http.HandlerFunc(handleAdminUI)))
|
||
|
||
port := os.Getenv("PORT")
|
||
if port == "" {
|
||
port = "8080"
|
||
}
|
||
addr := ":" + port
|
||
log.Printf("pgvecter API listening on %s", addr)
|
||
if err := http.ListenAndServe(addr, mux); err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}
|