diff --git a/server/src/index.ts b/server/src/index.ts
index 92e433a..962a711 100644
--- a/server/src/index.ts
+++ b/server/src/index.ts
@@ -2,6 +2,7 @@ import "dotenv/config";
 import express from 'express';
 import multer from 'multer';
 import cors from 'cors';
+import path from 'path';
 import { query, initializeTables } from './db';
 import { parse } from 'csv-parse';
 import { createReadStream } from 'fs';
@@ -25,7 +26,6 @@ function guessSqlType(value: any): string {
 
 function normalizeColumnName(column: string): string {
   const reservedKeywords = ['user', 'group', 'order', 'select', 'where', 'from', 'table', 'column'];
-// console.log(column);
   let normalized = column.trim()
     .toLowerCase()
     .replace(/[^a-zA-Z0-9]/g, '_')
@@ -40,6 +40,20 @@ function normalizeColumnName(column: string): string {
   return normalized;
 }
 
+function validateTableName(tableName: string): boolean {
+  return /^[a-zA-Z0-9_]+$/.test(tableName);
+}
+
+function validateColumnNames(columns: string[]): boolean {
+  return columns.every(column => /^[a-zA-Z0-9_]+$/.test(column));
+}
+
+function validateFilePath(filePath: string, uploadDir: string): boolean {
+  const resolvedUploadDir = path.resolve(uploadDir);
+  const resolvedFilePath = path.resolve(filePath);
+  return resolvedFilePath.startsWith(resolvedUploadDir + path.sep);
+}
+
 async function startServer() {
   // Initialize database tables
   await initializeTables();
@@ -59,65 +73,107 @@ async function startServer() {
       return res.status(400).json({ error: 'Table name is required' });
     }
 
+    // Validate table name to prevent SQL injection
+    if (!validateTableName(tableName)) {
+      return res.status(400).json({ error: 'Invalid table name' });
+    }
+
+    // Validate file path to prevent path traversal
+    const uploadDir = 'uploads';
+    if (!validateFilePath(req.file.path, uploadDir)) {
+      return res.status(400).json({ error: 'Invalid file path' });
+    }
+
+    // Single pass CSV processing with sampling and batch insertion
     const csvStream = createReadStream(req.file.path);
     const 
parser = parse({ columns: true, skip_empty_lines: true });
 
-    // Collect first 10 rows to analyze column types
     const sampleRows: any[] = [];
     const columnTypes = new Map();
+    const allRecords: any[] = [];
+    let isInitialized = false;
+    let columns: string[] = [];
 
     for await (const record of csvStream.pipe(parser)) {
-      sampleRows.push(record);
-      if (sampleRows.length === 10) break;
+      allRecords.push(record);
+
+      // Collect first 10 rows for sampling
+      if (sampleRows.length < 10) {
+        sampleRows.push(record);
+      }
+
+      // Initialize table after collecting first sample
+      if (sampleRows.length === 1 && !isInitialized) {
+        // Determine column types from sample data
+        columns = Object.keys(sampleRows[0]).map(normalizeColumnName);
+
+        // Validate column names to prevent SQL injection
+        if (!validateColumnNames(columns)) {
+          return res.status(400).json({ error: 'Invalid column names detected' });
+        }
+
+        columns.forEach((column, index) => {
+          const originalColumn = Object.keys(sampleRows[0])[index];
+          const value = sampleRows[0][originalColumn];
+          columnTypes.set(column, guessSqlType(value));
+        });
+
+        // Drop existing table if it exists
+        await query(`DROP TABLE IF EXISTS "${tableName}"`);
+
+        // Create new table with quoted identifiers
+        const createTableSQL = `
+          CREATE TABLE "${tableName}" (
+            ${columns.map(column => `"${column}" ${columnTypes.get(column)}`).join(',\n')}
+          )
+        `;
+        console.log(createTableSQL);
+        await query(createTableSQL);
+        isInitialized = true;
+      }
     }
 
-    if (sampleRows.length === 0) {
+    if (allRecords.length === 0) {
       return res.status(400).json({ error: 'CSV file is empty' });
     }
 
-    // Determine column types from sample data
-    const columns = Object.keys(sampleRows[0]).map(normalizeColumnName);
-    columns.forEach((column, index) => {
-      const originalColumn = Object.keys(sampleRows[0])[index];
-      const values = sampleRows.map(row => row[originalColumn]).filter(v => v !== null && v !== '');
-      columnTypes.set(column, guessSqlType(values[0]));
-    });
-
-    // Drop 
existing table if it exists
-    await query(`DROP TABLE IF EXISTS ${tableName}`);
-
-    // Create new table
-    const createTableSQL = `
-      CREATE TABLE ${tableName} (
-        ${columns.map(column => `${column} ${columnTypes.get(column)}`).join(',\n')}
-      )
-    `;
-    console.log(createTableSQL);
-    await query(createTableSQL);
-
-    // Reset stream for full import
-    const insertStream = createReadStream(req.file.path);
-    const insertParser = insertStream.pipe(parse({
-      columns: true,
-      skip_empty_lines: true
-    }));
-
-    // Insert all records
-    for await (const record of insertParser) {
-      const originalColumns = Object.keys(record);
+    // Batch insert records
+    const batchSize = 100;
+    const originalColumns = Object.keys(allRecords[0]);
+
+    for (let i = 0; i < allRecords.length; i += batchSize) {
+      const batch = allRecords.slice(i, i + batchSize);
+      const values: any[] = [];
+      const placeholders: string[] = [];
+
+      batch.forEach((record, batchIndex) => {
+        const recordValues = originalColumns.map(c => record[c]);
+        values.push(...recordValues);
+        const recordPlaceholders = columns.map((_, colIndex) =>
+          `$${batchIndex * columns.length + colIndex + 1}`
+        ).join(', ');
+        placeholders.push(`(${recordPlaceholders})`);
+      });
+
       const insertSQL = `
-        INSERT INTO ${tableName} (${columns.map(c => `"${c}"`).join(', ')})
-        VALUES (${columns.map((_, i) => `$${i + 1}`).join(', ')})
+        INSERT INTO "${tableName}" (${columns.map(c => `"${c}"`).join(', ')})
+        VALUES ${placeholders.join(', ')}
       `;
-      await query(insertSQL, originalColumns.map(c => record[c]));
+
+      await query(insertSQL, values);
     }
 
     // After successful upload, analyze the table and store the results
     const analysis = await analyzeTable(tableName);
 
+    // Validate table name again before storing in schema
+    if (!validateTableName(tableName)) {
+      return res.status(400).json({ error: 'Invalid table name' });
+    }
+
     // Store the analysis in TABLE_SCHEMA
     await query(
       `INSERT INTO TABLE_SCHEMA (table_name, analysis)
@@ -132,6 +188,7 @@ async function 
startServer() {
     res.json({
       message: 'CSV data successfully imported to database',
       tableName,
+      recordCount: allRecords.length,
       columnCount: columns.length,
       columnTypes: Object.fromEntries(columnTypes),
       analysis
@@ -146,6 +203,11 @@ async function startServer() {
   try {
     const { message } = req.body;
 
+    // Validate message parameter
+    if (typeof message !== 'string' || message.trim() === '') {
+      return res.status(400).json({ error: 'Invalid query message' });
+    }
+
     // Process the query using our new function
     const result = await processQuery(message);