Skip to content

Commit

Permalink
Merge pull request #182 from square/alec/fixup_initialization
Browse files Browse the repository at this point in the history
Fix up initialization so there are no race conditions on the symbol t…
  • Loading branch information
JakeWharton committed Mar 7, 2016
2 parents df979a7 + 749f387 commit 112c713
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal class Resolver(
if (selectStmt.K_WITH() != null) {
resolver = Resolver(selectStmt.common_table_expression()
.fold(symbolTable, { symbolTable, commonTable ->
symbolTable.merge(SymbolTable(commonTable), commonTable)
symbolTable + SymbolTable(commonTable, commonTable)
}), scopedValues)
} else {
resolver = this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ import com.squareup.sqldelight.SqliteParser
import com.squareup.sqldelight.SqlitePluginException
import java.util.LinkedHashMap

class SymbolTable private constructor(
internal val tables: Map<String, SqliteParser.Create_table_stmtContext>,
internal val views: Map<String, SqliteParser.Create_view_stmtContext>,
internal val commonTables: Map<String, SqliteParser.Common_table_expressionContext>,
private val tableTags: Map<Any, List<String>>,
private val viewTags: Map<Any, List<String>>
class SymbolTable constructor(
internal val tables: Map<String, SqliteParser.Create_table_stmtContext> = emptyMap(),
internal val views: Map<String, SqliteParser.Create_view_stmtContext> = emptyMap(),
internal val commonTables: Map<String, SqliteParser.Common_table_expressionContext> = emptyMap(),
private val tableTags: Map<Any, List<String>> = emptyMap(),
private val viewTags: Map<Any, List<String>> = emptyMap(),
private val tag: Any? = null
) {
constructor(
parsed: SqliteParser.ParseContext
parsed: SqliteParser.ParseContext,
tag: Any
) : this(
if (parsed.sql_stmt_list().create_table_stmt() != null) {
linkedMapOf(parsed.sql_stmt_list().create_table_stmt().table_name().text
Expand All @@ -40,26 +42,21 @@ class SymbolTable private constructor(
.filterNotNull()
.map { it.view_name().text to it }
.toTypedArray()),
emptyMap(),
emptyMap(),
emptyMap()
tag = tag
)

constructor(
commonTable: SqliteParser.Common_table_expressionContext
commonTable: SqliteParser.Common_table_expressionContext,
tag: Any
) : this(
emptyMap(),
emptyMap(),
mapOf(commonTable.table_name().text to commonTable),
emptyMap(),
emptyMap()
commonTables = mapOf(commonTable.table_name().text to commonTable),
tag = tag
)

constructor() : this(emptyMap(), emptyMap(), emptyMap(), emptyMap(), emptyMap())

fun merge(other: SymbolTable, otherTag: Any): SymbolTable {
operator fun plus(other: SymbolTable): SymbolTable {
if (other.tag == null) throw IllegalStateException("Symbol tables being added must have a tag")
val tables = LinkedHashMap(this.tables)
tableTags.filter({ it.key == otherTag }).flatMap({ it.value }).forEach { tables.remove(it) }
tableTags.filter({ it.key == other.tag }).flatMap({ it.value }).forEach { tables.remove(it) }
tables.keys.intersect(other.tables.keys).forEach {
throw SqlitePluginException(other.tables[it]!!.table_name(),
"Table already defined with name $it")
Expand All @@ -74,7 +71,7 @@ class SymbolTable private constructor(
}

val views = LinkedHashMap(this.views)
viewTags.filter({ it.key == otherTag }).flatMap({ it.value }).forEach { views.remove(it) }
viewTags.filter({ it.key == other.tag }).flatMap({ it.value }).forEach { views.remove(it) }
views.keys.intersect(other.tables.keys).forEach {
throw SqlitePluginException(other.tables[it]!!.table_name(),
"View already defined with name $it")
Expand All @@ -92,8 +89,8 @@ class SymbolTable private constructor(
tables + other.tables,
views + other.views,
this.commonTables + other.commonTables,
this.tableTags + (otherTag to other.tables.map { it.key }),
this.viewTags + (otherTag to other.views.map { it.key })
this.tableTags + (other.tag to other.tables.map { it.key }),
this.viewTags + (other.tag to other.views.map { it.key })
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import java.io.FileInputStream

class ResolverTests {
private val parsed = parse(File("src/test/data/ResolverTestData.sq"))
private val symbolTable = SymbolTable(parsed)
private val symbolTable = SymbolTable(parsed, parsed)
private val resolver = Resolver(symbolTable)

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ open class SqlDelightTask : SourceTask() {
getInputs().files.forEach { file ->
file.parseThen { parsed ->
try {
symbolTable = symbolTable.merge(SymbolTable(parsed), file.name)
symbolTable += SymbolTable(parsed, file.name)
} catch (e: SqlitePluginException) {
throw SqlitePluginException(e.originatingElement,
Status.Failure(e.originatingElement, e.message).message(file))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ public interface HockeyPlayerModel {
+ "FROM temp_table2\n"
+ "JOIN temp_table";

String IS_NOT_EXPR = ""
+ "SELECT *\n"
+ "FROM hockey_player\n"
+ "WHERE _id IS NOT 2";

long _id();

String first_name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,9 @@ WITH temp_table AS (
)
SELECT *
FROM temp_table2
JOIN temp_table;
JOIN temp_table;

is_not_expr:
SELECT *
FROM hockey_player
WHERE _id IS NOT 2;
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,35 @@
package com.squareup.sqldelight

import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.project.Project
import com.intellij.openapi.roots.ProjectRootManager
import com.intellij.openapi.startup.StartupActivity
import com.intellij.util.messages.Topic
import com.intellij.psi.PsiManager
import com.squareup.sqldelight.lang.SqlDelightFileViewProvider
import com.squareup.sqldelight.lang.SqliteContentIterator
import com.squareup.sqldelight.lang.SqliteFile
import com.squareup.sqldelight.types.SymbolTable

class SqlDelightStartupActivity : StartupActivity {
interface SqlDelightStartupListener {
fun startupCompleted(project: Project)
}

override fun runActivity(project: Project) {
ApplicationManager.getApplication().messageBus.syncPublisher(TOPIC).startupCompleted(project)
}

companion object {
val TOPIC = Topic.create("SqlDelight plugin completed startup",
SqlDelightStartupListener::class.java)
var files = arrayListOf<SqliteFile>()
ProjectRootManager.getInstance(project).fileIndex
.iterateContent(SqliteContentIterator(PsiManager.getInstance(project)) { file ->
files.add(file)
true
})
files.forEach { file ->
file.parseThen { parsed ->
SqlDelightFileViewProvider.symbolTable += SymbolTable(parsed, file.virtualFile)
}
}
files.forEach { file ->
ApplicationManager.getApplication().executeOnPooledThread {
WriteCommandAction.runWriteCommandAction(project, {
(file.viewProvider as SqlDelightFileViewProvider).generateJavaInterface()
})
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,19 @@
package com.squareup.sqldelight.lang

import com.intellij.lang.Language
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.module.ModuleUtil
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.FileViewProvider
import com.intellij.psi.FileViewProviderFactory
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.PsiManager
import com.intellij.psi.SingleRootFileViewProvider
import com.squareup.sqldelight.SqlDelightStartupActivity
import com.squareup.sqldelight.SqliteCompiler
import com.squareup.sqldelight.SqliteLexer
import com.squareup.sqldelight.SqliteParser
import com.squareup.sqldelight.SqlitePluginException
import com.squareup.sqldelight.Status
import com.squareup.sqldelight.model.relativePath
import com.squareup.sqldelight.types.SymbolTable
import com.squareup.sqldelight.validation.SqlDelightValidator
import org.antlr.v4.runtime.ANTLRInputStream
import org.antlr.v4.runtime.BaseErrorListener
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.RecognitionException
import org.antlr.v4.runtime.Recognizer
import java.io.File

class SqlDelightFileViewProviderFactory : FileViewProviderFactory {
Expand All @@ -59,36 +47,14 @@ internal class SqlDelightFileViewProvider(virtualFile: VirtualFile, language: La
getPsiInner(SqliteLanguage.INSTANCE) as SqliteFile
}

init {
val connection = ApplicationManager.getApplication().messageBus.connect()


ApplicationManager.getApplication().runReadAction {
file.parseThen { parsed ->
symbolTable = symbolTable.merge(SymbolTable(parsed), virtualFile)
}
}

connection.subscribe(SqlDelightStartupActivity.TOPIC,
object : SqlDelightStartupActivity.SqlDelightStartupListener {
override fun startupCompleted(project: Project) {
if (project != file.project) return
ApplicationManager.getApplication().executeOnPooledThread {
WriteCommandAction.runWriteCommandAction(project, { generateJavaInterface() })
}
connection.disconnect()
}
})
}

override fun contentsSynchronized() {
super.contentsSynchronized()
documentManager.performWhenAllCommitted { generateJavaInterface() }
}

private fun generateJavaInterface() {
internal fun generateJavaInterface() {
file.parseThen { parsed ->
symbolTable = symbolTable.merge(SymbolTable(parsed), virtualFile)
symbolTable += SymbolTable(parsed, virtualFile)
sqdelightValidator.validate(parsed, symbolTable)

val status = sqliteCompiler.write(
Expand All @@ -115,42 +81,7 @@ internal class SqlDelightFileViewProvider(virtualFile: VirtualFile, language: La
private val sqliteCompiler = SqliteCompiler()
private val sqdelightValidator = SqlDelightValidator()

private var symbolTable = SymbolTable()

private fun SqliteFile.parseThen(operation: (SqliteParser.ParseContext) -> Unit) {
synchronized (sqliteCompiler) {
val errorListener = GeneratingErrorListener()
val lexer = SqliteLexer(ANTLRInputStream(text))
lexer.removeErrorListeners()
lexer.addErrorListener(errorListener)

val parser = SqliteParser(CommonTokenStream(lexer))
parser.removeErrorListeners()
parser.addErrorListener(errorListener)

val parsed = parser.parse()

if (errorListener.hasError) {
// Syntax level errors are handled by the annotator. Don't generate anything.
return
}

try {
operation(parsed)
} catch (e: SqlitePluginException) {
status = Status.Failure(e.originatingElement, e.message)
}
}
}
}
}

private class GeneratingErrorListener : BaseErrorListener() {
internal var hasError = false

override fun syntaxError(recognizer: Recognizer<*, *>?, offendingSymbol: Any?, line: Int,
charPositionInLine: Int, msg: String?, e: RecognitionException?) {
hasError = true
internal var symbolTable = SymbolTable()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ package com.squareup.sqldelight.lang

import com.intellij.openapi.roots.ContentIterator
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiManager
import com.squareup.sqldelight.SqliteCompiler

class SqliteContentIterator(private val psiManager: PsiManager,
private val processor: (file: PsiFile) -> Boolean) : ContentIterator {
private val processor: (file: SqliteFile) -> Boolean) : ContentIterator {
override fun processFile(fileOrDir: VirtualFile): Boolean {
return fileOrDir.isDirectory || fileOrDir.extension != SqliteCompiler.FILE_EXTENSION ||
processor(psiManager.findFile(fileOrDir) ?: return true)
processor(psiManager.findFile(fileOrDir) as? SqliteFile ?: return true)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,56 @@ package com.squareup.sqldelight.lang
import com.intellij.extapi.psi.PsiFileBase
import com.intellij.psi.FileViewProvider
import com.intellij.psi.PsiFile
import com.squareup.sqldelight.SqliteLexer
import com.squareup.sqldelight.SqlitePluginException
import com.squareup.sqldelight.Status
import org.antlr.v4.runtime.ANTLRInputStream
import org.antlr.v4.runtime.BaseErrorListener
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.RecognitionException
import org.antlr.v4.runtime.Recognizer
import kotlin.properties.Delegates

class SqliteFile internal constructor(viewProvider: FileViewProvider)
: PsiFileBase(viewProvider, SqliteLanguage.INSTANCE) {
var generatedFile: PsiFile? = null
var status: Status? = null
var status: Status by Delegates.notNull()

override fun getFileType() = SqliteFileType.INSTANCE
override fun toString() = "SQLite file"

fun parseThen(operation: (com.squareup.sqldelight.SqliteParser.ParseContext) -> Unit) {
synchronized (project) {
val errorListener = GeneratingErrorListener()
val lexer = SqliteLexer(ANTLRInputStream(text))
lexer.removeErrorListeners()
lexer.addErrorListener(errorListener)

val parser = com.squareup.sqldelight.SqliteParser(CommonTokenStream(lexer))
parser.removeErrorListeners()
parser.addErrorListener(errorListener)

val parsed = parser.parse()

if (errorListener.hasError) {
// Syntax level errors are handled by the annotator. Don't generate anything.
return
}

try {
operation(parsed)
} catch (e: SqlitePluginException) {
status = Status.Failure(e.originatingElement, e.message)
}
}
}

private class GeneratingErrorListener : BaseErrorListener() {
internal var hasError = false

override fun syntaxError(recognizer: Recognizer<*, *>?, offendingSymbol: Any?, line: Int,
charPositionInLine: Int, msg: String?, e: RecognitionException?) {
hasError = true
}
}
}

0 comments on commit 112c713

Please sign in to comment.